diff --git a/core/Moonshark.go b/core/Moonshark.go
index 752c238..d4a103f 100644
--- a/core/Moonshark.go
+++ b/core/Moonshark.go
@@ -185,7 +185,7 @@ func (s *Moonshark) initRunner() error {
runner.WithPoolSize(s.Config.Runner.PoolSize),
runner.WithLibDirs(s.Config.Dirs.Libs...),
runner.WithSessionManager(sessionManager),
- runner.WithCSRFProtection(),
+ http.WithCSRFProtection(),
}
// Add debug option conditionally
diff --git a/core/http/Csrf.go b/core/http/Csrf.go
index 1830b78..7347ad4 100644
--- a/core/http/Csrf.go
+++ b/core/http/Csrf.go
@@ -1,12 +1,119 @@
package http
import (
+ "Moonshark/core/runner"
"Moonshark/core/utils"
"Moonshark/core/utils/logger"
+ "crypto/subtle"
"github.com/valyala/fasthttp"
+
+ luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
+// ValidateCSRFToken checks if the CSRF token is valid for a request
+func ValidateCSRFToken(state *luajit.State, ctx *runner.Context) bool {
+ // Only validate for form submissions
+ method, ok := ctx.Get("method").(string)
+ if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
+ return true
+ }
+
+ // Get form data
+ formData, ok := ctx.Get("form").(map[string]any)
+ if !ok || formData == nil {
+ logger.Warning("CSRF validation failed: no form data")
+ return false
+ }
+
+ // Get token from form
+ formToken, ok := formData["csrf"].(string)
+ if !ok || formToken == "" {
+ logger.Warning("CSRF validation failed: no token in form")
+ return false
+ }
+
+ // Get session token
+ state.GetGlobal("session")
+ if state.IsNil(-1) {
+ state.Pop(1)
+ logger.Warning("CSRF validation failed: session module not available")
+ return false
+ }
+
+ state.GetField(-1, "get")
+ if !state.IsFunction(-1) {
+ state.Pop(2)
+ logger.Warning("CSRF validation failed: session.get not available")
+ return false
+ }
+
+ state.PushCopy(-1) // Duplicate function
+ state.PushString("_csrf_token")
+
+ if err := state.Call(1, 1); err != nil {
+ state.Pop(3) // Pop error, function and session table
+ logger.Warning("CSRF validation failed: %v", err)
+ return false
+ }
+
+ if state.IsNil(-1) {
+ state.Pop(3) // Pop nil, function and session table
+ logger.Warning("CSRF validation failed: no token in session")
+ return false
+ }
+
+ sessionToken := state.ToString(-1)
+ state.Pop(3) // Pop token, function and session table
+
+ // Constant-time comparison to prevent timing attacks
+ return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
+}
+
+// WithCSRFProtection creates a runner option to add CSRF protection
+func WithCSRFProtection() runner.RunnerOption {
+ return func(r *runner.Runner) {
+ r.AddInitHook(func(state *luajit.State, ctx *runner.Context) error {
+ // Get request method
+ method, ok := ctx.Get("method").(string)
+ if !ok {
+ return nil
+ }
+
+ // Only validate for form submissions
+ if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
+ return nil
+ }
+
+ // Check for form data
+ form, ok := ctx.Get("form").(map[string]any)
+ if !ok || form == nil {
+ return nil
+ }
+
+ // Validate CSRF token
+ if !ValidateCSRFToken(state, ctx) {
+ return ErrCSRFValidationFailed
+ }
+
+ return nil
+ })
+ }
+}
+
+// Error for CSRF validation failure
+var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
+
+// CSRFError represents a CSRF validation error
+type CSRFError struct {
+ message string
+}
+
+// Error implements the error interface
+func (e *CSRFError) Error() string {
+ return e.message
+}
+
// HandleCSRFError handles a CSRF validation error
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
method := string(ctx.Method())
diff --git a/core/http/Server.go b/core/http/Server.go
index 6b87c0a..fa23a49 100644
--- a/core/http/Server.go
+++ b/core/http/Server.go
@@ -8,6 +8,7 @@ import (
"Moonshark/core/metadata"
"Moonshark/core/routers"
"Moonshark/core/runner"
+ "Moonshark/core/runner/sandbox"
"Moonshark/core/utils"
"Moonshark/core/utils/config"
"Moonshark/core/utils/logger"
@@ -226,7 +227,7 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
// Special handling for CSRF error
if err != nil {
- if csrfErr, ok := err.(*runner.CSRFError); ok {
+ if csrfErr, ok := err.(*CSRFError); ok {
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
HandleCSRFError(ctx, s.errorConfig)
return
@@ -258,8 +259,8 @@ func writeResponse(ctx *fasthttp.RequestCtx, result any) {
}
// Check for HTTPResponse type
- if httpResp, ok := result.(*runner.HTTPResponse); ok {
- defer runner.ReleaseResponse(httpResp)
+ if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
+ defer sandbox.ReleaseResponse(httpResp)
// Set response headers
for name, value := range httpResp.Headers {
diff --git a/core/runner/Context.go b/core/runner/Context.go
index 673c097..9586030 100644
--- a/core/runner/Context.go
+++ b/core/runner/Context.go
@@ -3,6 +3,8 @@ package runner
import (
"sync"
+ "maps"
+
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
@@ -24,7 +26,7 @@ type Context struct {
// Context pool to reduce allocations
var contextPool = sync.Pool{
- New: func() interface{} {
+ New: func() any {
return &Context{
Values: make(map[string]any, 16), // Pre-allocate with reasonable capacity
}
@@ -115,9 +117,7 @@ func (c *Context) All() map[string]any {
defer c.mu.RUnlock()
result := make(map[string]any, len(c.Values))
- for k, v := range c.Values {
- result[k] = v
- }
+ maps.Copy(result, c.Values)
return result
}
diff --git a/core/runner/CookieModule.go b/core/runner/CookieModule.go
deleted file mode 100644
index 356217b..0000000
--- a/core/runner/CookieModule.go
+++ /dev/null
@@ -1,187 +0,0 @@
-package runner
-
-import (
- "time"
-
- luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
- "github.com/valyala/fasthttp"
-)
-
-// extractCookie grabs cookies from the Lua state
-func extractCookie(state *luajit.State) *fasthttp.Cookie {
- cookie := fasthttp.AcquireCookie()
-
- // Get name
- state.GetField(-1, "name")
- if !state.IsString(-1) {
- state.Pop(1)
- fasthttp.ReleaseCookie(cookie)
- return nil // Name is required
- }
- cookie.SetKey(state.ToString(-1))
- state.Pop(1)
-
- // Get value
- state.GetField(-1, "value")
- if state.IsString(-1) {
- cookie.SetValue(state.ToString(-1))
- }
- state.Pop(1)
-
- // Get path
- state.GetField(-1, "path")
- if state.IsString(-1) {
- cookie.SetPath(state.ToString(-1))
- } else {
- cookie.SetPath("/") // Default path
- }
- state.Pop(1)
-
- // Get domain
- state.GetField(-1, "domain")
- if state.IsString(-1) {
- cookie.SetDomain(state.ToString(-1))
- }
- state.Pop(1)
-
- // Get expires
- state.GetField(-1, "expires")
- if state.IsNumber(-1) {
- expiry := int64(state.ToNumber(-1))
- cookie.SetExpire(time.Unix(expiry, 0))
- }
- state.Pop(1)
-
- // Get max age
- state.GetField(-1, "max_age")
- if state.IsNumber(-1) {
- cookie.SetMaxAge(int(state.ToNumber(-1)))
- }
- state.Pop(1)
-
- // Get secure
- state.GetField(-1, "secure")
- if state.IsBoolean(-1) {
- cookie.SetSecure(state.ToBoolean(-1))
- }
- state.Pop(1)
-
- // Get http only
- state.GetField(-1, "http_only")
- if state.IsBoolean(-1) {
- cookie.SetHTTPOnly(state.ToBoolean(-1))
- }
- state.Pop(1)
-
- return cookie
-}
-
-// LuaCookieModule provides cookie functionality to Lua scripts
-const LuaCookieModule = `
--- Cookie module implementation
-local cookie = {
- -- Set a cookie
- set = function(name, value, options, ...)
- if type(name) ~= "string" then
- error("cookie.set: name must be a string", 2)
- end
-
- -- Get or create response
- local resp = __http_responses[1] or {}
- resp.cookies = resp.cookies or {}
- __http_responses[1] = resp
-
- -- Handle options as table or legacy params
- local opts = {}
- if type(options) == "table" then
- opts = options
- elseif options ~= nil then
- -- Legacy support: options is actually 'expires'
- opts.expires = options
- -- Check for other legacy params (4th-7th args)
- local args = {...}
- if args[1] then opts.path = args[1] end
- if args[2] then opts.domain = args[2] end
- if args[3] then opts.secure = args[3] end
- if args[4] ~= nil then opts.http_only = args[4] end
- end
-
- -- Create cookie table
- local cookie = {
- name = name,
- value = value or "",
- path = opts.path or "/",
- domain = opts.domain
- }
-
- -- Handle expiry
- if opts.expires then
- if type(opts.expires) == "number" then
- if opts.expires > 0 then
- cookie.max_age = opts.expires
- local now = os.time()
- cookie.expires = now + opts.expires
- elseif opts.expires < 0 then
- cookie.expires = 1
- cookie.max_age = 0
- else
- -- opts.expires == 0: Session cookie
- -- Do nothing (omitting both expires and max-age creates a session cookie)
- end
- end
- end
-
- -- Security flags
- cookie.secure = (opts.secure ~= false)
- cookie.http_only = (opts.http_only ~= false)
-
- -- Store in cookies table
- local n = #resp.cookies + 1
- resp.cookies[n] = cookie
-
- return true
- end,
-
- -- Get a cookie value
- get = function(name)
- if type(name) ~= "string" then
- error("cookie.get: name must be a string", 2)
- end
-
- -- Access values directly from current environment
- local env = getfenv(2)
-
- -- Check if context exists and has cookies
- if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
- return tostring(env.ctx.cookies[name])
- end
-
- return nil
- end,
-
- -- Remove a cookie
- remove = function(name, path, domain)
- if type(name) ~= "string" then
- error("cookie.remove: name must be a string", 2)
- end
-
- -- Create an expired cookie
- return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
- end
-}
-
--- Install cookie module
-_G.cookie = cookie
-
--- Make sure the cookie module is accessible in sandbox
-if __env_system and __env_system.base_env then
- __env_system.base_env.cookie = cookie
-end
-`
-
-// CookieModuleInitFunc returns an initializer for the cookie module
-func CookieModuleInitFunc() StateInitFunc {
- return func(state *luajit.State) error {
- return state.DoString(LuaCookieModule)
- }
-}
diff --git a/core/runner/Cookies.go b/core/runner/Cookies.go
new file mode 100644
index 0000000..b1258e4
--- /dev/null
+++ b/core/runner/Cookies.go
@@ -0,0 +1,77 @@
+package runner
+
+import (
+ "time"
+
+ luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
+ "github.com/valyala/fasthttp"
+)
+
+// extractCookie grabs cookies from the Lua state
+func extractCookie(state *luajit.State) *fasthttp.Cookie {
+ cookie := fasthttp.AcquireCookie()
+
+ // Get name
+ state.GetField(-1, "name")
+ if !state.IsString(-1) {
+ state.Pop(1)
+ fasthttp.ReleaseCookie(cookie)
+ return nil // Name is required
+ }
+ cookie.SetKey(state.ToString(-1))
+ state.Pop(1)
+
+ // Get value
+ state.GetField(-1, "value")
+ if state.IsString(-1) {
+ cookie.SetValue(state.ToString(-1))
+ }
+ state.Pop(1)
+
+ // Get path
+ state.GetField(-1, "path")
+ if state.IsString(-1) {
+ cookie.SetPath(state.ToString(-1))
+ } else {
+ cookie.SetPath("/") // Default path
+ }
+ state.Pop(1)
+
+ // Get domain
+ state.GetField(-1, "domain")
+ if state.IsString(-1) {
+ cookie.SetDomain(state.ToString(-1))
+ }
+ state.Pop(1)
+
+ // Get expires
+ state.GetField(-1, "expires")
+ if state.IsNumber(-1) {
+ expiry := int64(state.ToNumber(-1))
+ cookie.SetExpire(time.Unix(expiry, 0))
+ }
+ state.Pop(1)
+
+ // Get max age
+ state.GetField(-1, "max_age")
+ if state.IsNumber(-1) {
+ cookie.SetMaxAge(int(state.ToNumber(-1)))
+ }
+ state.Pop(1)
+
+ // Get secure
+ state.GetField(-1, "secure")
+ if state.IsBoolean(-1) {
+ cookie.SetSecure(state.ToBoolean(-1))
+ }
+ state.Pop(1)
+
+ // Get http only
+ state.GetField(-1, "http_only")
+ if state.IsBoolean(-1) {
+ cookie.SetHTTPOnly(state.ToBoolean(-1))
+ }
+ state.Pop(1)
+
+ return cookie
+}
diff --git a/core/runner/CoreModules.go b/core/runner/CoreModules.go
index bf25778..98f18dc 100644
--- a/core/runner/CoreModules.go
+++ b/core/runner/CoreModules.go
@@ -1,6 +1,7 @@
package runner
import (
+ "Moonshark/core/runner/sandbox"
"Moonshark/core/utils/logger"
"fmt"
"strings"
@@ -265,17 +266,21 @@ func init() {
GlobalRegistry.EnableDebug() // Enable debugging by default
logger.Debug("[ModuleRegistry] Registering core modules...")
- GlobalRegistry.Register("util", UtilModuleInitFunc())
- GlobalRegistry.Register("http", HTTPModuleInitFunc())
- GlobalRegistry.RegisterWithDependencies("cookie", CookieModuleInitFunc(), []string{"http"})
- GlobalRegistry.RegisterWithDependencies("csrf", CSRFModuleInitFunc(), []string{"util"})
+ // Register core modules - these now point to the sandbox implementations
+ GlobalRegistry.Register("util", func(state *luajit.State) error {
+ return sandbox.UtilModuleInitFunc()(state)
+ })
+
+ GlobalRegistry.Register("http", func(state *luajit.State) error {
+ return sandbox.HTTPModuleInitFunc()(state)
+ })
// Set explicit initialization order
GlobalRegistry.SetInitOrder([]string{
- "util", // First: core utilities
- "http", // Second: HTTP functionality
- "cookie", // Third: Cookie functionality (uses HTTP)
- "csrf", // Fourth: CSRF protection (uses go and possibly session)
+ "util", // First: core utilities
+ "http", // Second: HTTP functionality
+ "session", // Third: Session functionality
+ "csrf", // Fourth: CSRF protection
})
logger.DebugCont("Core modules registered successfully")
diff --git a/core/runner/CsrfModule.go b/core/runner/CsrfModule.go
deleted file mode 100644
index 683f35f..0000000
--- a/core/runner/CsrfModule.go
+++ /dev/null
@@ -1,219 +0,0 @@
-package runner
-
-import (
- "crypto/subtle"
-
- "Moonshark/core/utils/logger"
-
- luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
-)
-
-// LuaCSRFModule provides CSRF protection functionality to Lua scripts
-const LuaCSRFModule = `
--- CSRF protection module
-local csrf = {
- -- Session key where the token is stored
- TOKEN_KEY = "_csrf_token",
-
- -- Default form field name
- DEFAULT_FIELD = "csrf",
-
- -- Generate a new CSRF token and store it in the session
- generate = function(length)
- -- Default length is 32 characters
- length = length or 32
-
- if length < 16 then
- -- Enforce minimum security
- length = 16
- end
-
- -- Check if we have a session module
- if not session then
- error("CSRF protection requires the session module", 2)
- end
-
- local token = util.generate_token(length)
- session.set(csrf.TOKEN_KEY, token)
- return token
- end,
-
- -- Get the current token or generate a new one
- token = function()
- -- Get from session if exists
- local token = session.get(csrf.TOKEN_KEY)
-
- -- Generate if needed
- if not token then
- token = csrf.generate()
- end
-
- return token
- end,
-
- -- Generate a hidden form field with the CSRF token
- field = function(field_name)
- field_name = field_name or csrf.DEFAULT_FIELD
- local token = csrf.token()
- return string.format('', field_name, token)
- end,
-
- -- Verify a given token against the session token
- verify = function(token, field_name)
- field_name = field_name or csrf.DEFAULT_FIELD
-
- local env = getfenv(2)
-
- local form = nil
- if env.ctx and env.ctx.form then
- form = env.ctx.form
- else
- return false
- end
-
- token = token or form[field_name]
- if not token then
- return false
- end
-
- local session_token = session.get(csrf.TOKEN_KEY)
- if not session_token then
- return false
- end
-
- -- Constant-time comparison to prevent timing attacks
- -- This is safe since Lua strings are immutable
- if #token ~= #session_token then
- return false
- end
-
- local result = true
- for i = 1, #token do
- if token:sub(i, i) ~= session_token:sub(i, i) then
- result = false
- -- Don't break early - continue to prevent timing attacks
- end
- end
-
- return result
- end
-}
-
--- Install CSRF module
-_G.csrf = csrf
-
--- Make sure the CSRF module is accessible in sandbox
-if __env_system and __env_system.base_env then
- __env_system.base_env.csrf = csrf
-end
-`
-
-// CSRFModuleInitFunc returns an initializer for the CSRF module
-func CSRFModuleInitFunc() StateInitFunc {
- return func(state *luajit.State) error {
- return state.DoString(LuaCSRFModule)
- }
-}
-
-// ValidateCSRFToken checks if the CSRF token is valid for a request
-func ValidateCSRFToken(state *luajit.State, ctx *Context) bool {
- // Only validate for form submissions
- method, ok := ctx.Get("method").(string)
- if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
- return true
- }
-
- // Get form data
- formData, ok := ctx.Get("form").(map[string]any)
- if !ok || formData == nil {
- logger.Warning("CSRF validation failed: no form data")
- return false
- }
-
- // Get token from form
- formToken, ok := formData["csrf"].(string)
- if !ok || formToken == "" {
- logger.Warning("CSRF validation failed: no token in form")
- return false
- }
-
- // Get session token
- state.GetGlobal("session")
- if state.IsNil(-1) {
- state.Pop(1)
- logger.Warning("CSRF validation failed: session module not available")
- return false
- }
-
- state.GetField(-1, "get")
- if !state.IsFunction(-1) {
- state.Pop(2)
- logger.Warning("CSRF validation failed: session.get not available")
- return false
- }
-
- state.PushCopy(-1) // Duplicate function
- state.PushString("_csrf_token")
-
- if err := state.Call(1, 1); err != nil {
- state.Pop(3) // Pop error, function and session table
- logger.Warning("CSRF validation failed: %v", err)
- return false
- }
-
- if state.IsNil(-1) {
- state.Pop(3) // Pop nil, function and session table
- logger.Warning("CSRF validation failed: no token in session")
- return false
- }
-
- sessionToken := state.ToString(-1)
- state.Pop(3) // Pop token, function and session table
-
- // Constant-time comparison to prevent timing attacks
- return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
-}
-
-// WithCSRFProtection creates a runner option to add CSRF protection
-func WithCSRFProtection() RunnerOption {
- return func(r *Runner) {
- r.AddInitHook(func(state *luajit.State, ctx *Context) error {
- // Get request method
- method, ok := ctx.Get("method").(string)
- if !ok {
- return nil
- }
-
- // Only validate for form submissions
- if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
- return nil
- }
-
- // Check for form data
- form, ok := ctx.Get("form").(map[string]any)
- if !ok || form == nil {
- return nil
- }
-
- // Validate CSRF token
- if !ValidateCSRFToken(state, ctx) {
- return ErrCSRFValidationFailed
- }
-
- return nil
- })
- }
-}
-
-// Error for CSRF validation failure
-var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
-
-// CSRFError represents a CSRF validation error
-type CSRFError struct {
- message string
-}
-
-// Error implements the error interface
-func (e *CSRFError) Error() string {
- return e.message
-}
diff --git a/core/runner/Runner.go b/core/runner/Runner.go
index 6455a4b..de6a647 100644
--- a/core/runner/Runner.go
+++ b/core/runner/Runner.go
@@ -12,6 +12,7 @@ import (
"github.com/panjf2000/ants/v2"
"github.com/valyala/bytebufferpool"
+ "Moonshark/core/runner/sandbox"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
@@ -30,11 +31,11 @@ type RunnerOption func(*Runner)
// State wraps a Lua state with its sandbox
type State struct {
- L *luajit.State // The Lua state
- sandbox *Sandbox // Associated sandbox
- index int // Index for debugging
- inUse bool // Whether the state is currently in use
- initTime time.Time // When this state was initialized
+ L *luajit.State // The Lua state
+ sandbox *sandbox.Sandbox // Associated sandbox
+ index int // Index for debugging
+ inUse bool // Whether the state is currently in use
+ initTime time.Time // When this state was initialized
}
// InitHook runs before executing a script
@@ -217,7 +218,7 @@ func (r *Runner) createState(index int) (*State, error) {
}
// Create sandbox
- sandbox := NewSandbox()
+ sandbox := sandbox.NewSandbox()
if r.debug && verbose {
sandbox.EnableDebug()
}
diff --git a/core/runner/SessionModule.go b/core/runner/SessionModule.go
deleted file mode 100644
index bdea3a7..0000000
--- a/core/runner/SessionModule.go
+++ /dev/null
@@ -1,177 +0,0 @@
-package runner
-
-import (
- "Moonshark/core/utils/logger"
-
- luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
-)
-
-// LuaSessionModule provides session functionality to Lua scripts
-const LuaSessionModule = `
--- Global table to store session data
-__session_data = __session_data or {}
-__session_id = __session_id or nil
-__session_modified = false
-
--- Session module implementation
-local session = {
- -- Get a session value
- get = function(key)
- if type(key) ~= "string" then
- error("session.get: key must be a string", 2)
- end
-
- if __session_data and __session_data[key] then
- return __session_data[key]
- end
-
- return nil
- end,
-
- -- Set a session value
- set = function(key, value)
- if type(key) ~= "string" then
- error("session.set: key must be a string", 2)
- end
-
- -- Ensure session data table exists
- __session_data = __session_data or {}
-
- -- Store value
- __session_data[key] = value
-
- -- Mark session as modified
- __session_modified = true
-
- return true
- end,
-
- -- Delete a session value
- delete = function(key)
- if type(key) ~= "string" then
- error("session.delete: key must be a string", 2)
- end
-
- if __session_data then
- __session_data[key] = nil
- __session_modified = true
- end
-
- return true
- end,
-
- -- Clear all session data
- clear = function()
- __session_data = {}
- __session_modified = true
- return true
- end,
-
- -- Get the session ID
- get_id = function()
- return __session_id or nil
- end,
-
- -- Get all session data
- get_all = function()
- local result = {}
- for k, v in pairs(__session_data or {}) do
- result[k] = v
- end
- return result
- end,
-
- -- Check if session has a key
- has = function(key)
- if type(key) ~= "string" then
- error("session.has: key must be a string", 2)
- end
-
- return __session_data and __session_data[key] ~= nil
- end
-}
-
--- Install session module
-_G.session = session
-
--- Make sure the session module is accessible in sandbox
-if __env_system and __env_system.base_env then
- __env_system.base_env.session = session
-end
-
--- Hook into script execution to preserve session state
-local old_execute_script = __execute_script
-if old_execute_script then
- __execute_script = function(fn, ctx)
- -- Reset modification flag at the start of request
- __session_modified = false
-
- -- Execute original function
- return old_execute_script(fn, ctx)
- end
-end
-`
-
-// GetSessionData extracts session data from Lua state
-func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
- // Check if session was modified
- state.GetGlobal("__session_modified")
- modified := state.ToBoolean(-1)
- state.Pop(1)
-
- if !modified {
- return "", nil, false
- }
-
- // Get session ID
- state.GetGlobal("__session_id")
- sessionID := state.ToString(-1)
- state.Pop(1)
-
- // Get session data
- state.GetGlobal("__session_data")
- if !state.IsTable(-1) {
- state.Pop(1)
- return sessionID, nil, false
- }
-
- data, err := state.ToTable(-1)
- state.Pop(1)
-
- if err != nil {
- logger.Error("Failed to extract session data: %v", err)
- return sessionID, nil, false
- }
-
- return sessionID, data, true
-}
-
-// SetSessionData sets session data in Lua state
-func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
- // Set session ID
- state.PushString(sessionID)
- state.SetGlobal("__session_id")
-
- // Set session data
- if data == nil {
- data = make(map[string]any)
- }
-
- if err := state.PushTable(data); err != nil {
- return err
- }
- state.SetGlobal("__session_data")
-
- // Reset modification flag
- state.PushBoolean(false)
- state.SetGlobal("__session_modified")
-
- return nil
-}
-
-// SessionModuleInitFunc returns an initializer for the session module
-func SessionModuleInitFunc() StateInitFunc {
- return func(state *luajit.State) error {
- return state.DoString(LuaSessionModule)
- }
-}
diff --git a/core/runner/SessionHandler.go b/core/runner/Sessions.go
similarity index 76%
rename from core/runner/SessionHandler.go
rename to core/runner/Sessions.go
index b423b30..e73de0e 100644
--- a/core/runner/SessionHandler.go
+++ b/core/runner/Sessions.go
@@ -3,6 +3,7 @@ package runner
import (
"github.com/valyala/fasthttp"
+ "Moonshark/core/runner/sandbox"
"Moonshark/core/sessions"
"Moonshark/core/utils/logger"
@@ -40,9 +41,6 @@ func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
return func(r *Runner) {
handler := NewSessionHandler(manager)
- // Register the session module
- RegisterCoreModule("session", SessionModuleInitFunc())
-
// Add hooks to the runner
r.AddInitHook(handler.preRequestHook)
r.AddFinalizeHook(handler.postRequestHook)
@@ -140,8 +138,10 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu
session.Set(k, v)
}
+ h.manager.SaveSession(session)
+
// Add session cookie to result if it's an HTTP response
- if httpResp, ok := result.(*HTTPResponse); ok {
+ if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
h.addSessionCookie(httpResp, modifiedID)
}
@@ -150,7 +150,7 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu
}
// addSessionCookie adds a session cookie to an HTTP response
-func (h *SessionHandler) addSessionCookie(resp *HTTPResponse, sessionID string) {
+func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) {
// Get cookie options
opts := h.manager.CookieOptions()
@@ -184,3 +184,60 @@ func (h *SessionHandler) addSessionCookie(resp *HTTPResponse, sessionID string)
resp.Cookies = append(resp.Cookies, cookie)
}
+
+// GetSessionData extracts session data from Lua state
+func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
+ // Check if session was modified
+ state.GetGlobal("__session_modified")
+ modified := state.ToBoolean(-1)
+ state.Pop(1)
+
+ if !modified {
+ return "", nil, false
+ }
+
+ // Get session ID
+ state.GetGlobal("__session_id")
+ sessionID := state.ToString(-1)
+ state.Pop(1)
+
+ // Get session data
+ state.GetGlobal("__session_data")
+ if !state.IsTable(-1) {
+ state.Pop(1)
+ return sessionID, nil, false
+ }
+
+ data, err := state.ToTable(-1)
+ state.Pop(1)
+
+ if err != nil {
+ logger.Error("Failed to extract session data: %v", err)
+ return sessionID, nil, false
+ }
+
+ return sessionID, data, true
+}
+
+// SetSessionData sets session data in Lua state
+func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
+ // Set session ID
+ state.PushString(sessionID)
+ state.SetGlobal("__session_id")
+
+ // Set session data
+ if data == nil {
+ data = make(map[string]any)
+ }
+
+ if err := state.PushTable(data); err != nil {
+ return err
+ }
+ state.SetGlobal("__session_data")
+
+ // Reset modification flag
+ state.PushBoolean(false)
+ state.SetGlobal("__session_modified")
+
+ return nil
+}
diff --git a/core/runner/sandbox/Embed.go b/core/runner/sandbox/Embed.go
new file mode 100644
index 0000000..944f9d8
--- /dev/null
+++ b/core/runner/sandbox/Embed.go
@@ -0,0 +1,98 @@
+package sandbox
+
+import (
+ _ "embed"
+
+ "Moonshark/core/utils/logger"
+
+ luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
+)
+
+//go:embed lua/sandbox.lua
+var sandboxLua string
+
+// InitializeSandbox loads the embedded Lua sandbox code into a Lua state
+func InitializeSandbox(state *luajit.State) error {
+ // Compile once, use many times
+ bytecodeOnce.Do(precompileSandbox)
+
+ if sandboxBytecode != nil {
+ logger.Debug("Loading sandbox.lua from precompiled bytecode")
+ return state.LoadAndRunBytecode(sandboxBytecode, "sandbox.lua")
+ }
+
+ // Fallback if compilation failed
+ logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
+ return state.DoString(sandboxLua)
+}
+
+// ModuleInitializers stores initializer functions for core modules
+type ModuleInitializers struct {
+ HTTP func(*luajit.State) error
+ Util func(*luajit.State) error
+ Session func(*luajit.State) error
+ Cookie func(*luajit.State) error
+ CSRF func(*luajit.State) error
+}
+
+// DefaultInitializers returns the default set of initializers
+func DefaultInitializers() *ModuleInitializers {
+ return &ModuleInitializers{
+ HTTP: func(state *luajit.State) error {
+ // Register the native Go function first
+ if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
+ logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
+ return err
+ }
+ return nil
+ },
+ Util: func(state *luajit.State) error {
+ // Register util functions
+ return RegisterModule(state, "util", UtilModuleFunctions())
+ },
+ Session: func(state *luajit.State) error {
+ // Session doesn't need special initialization
+ return nil
+ },
+ Cookie: func(state *luajit.State) error {
+ // Cookie doesn't need special initialization
+ return nil
+ },
+ CSRF: func(state *luajit.State) error {
+ // CSRF doesn't need special initialization
+ return nil
+ },
+ }
+}
+
+// InitializeAll initializes all modules in the Lua state
+func InitializeAll(state *luajit.State, initializers *ModuleInitializers) error {
+ // Set up dependencies first
+ if err := initializers.Util(state); err != nil {
+ return err
+ }
+
+ if err := initializers.HTTP(state); err != nil {
+ return err
+ }
+
+ // Load the embedded sandbox code
+ if err := InitializeSandbox(state); err != nil {
+ return err
+ }
+
+ // Initialize the rest of the modules
+ if err := initializers.Session(state); err != nil {
+ return err
+ }
+
+ if err := initializers.Cookie(state); err != nil {
+ return err
+ }
+
+ if err := initializers.CSRF(state); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/core/runner/HttpModule.go b/core/runner/sandbox/Http.go
similarity index 66%
rename from core/runner/HttpModule.go
rename to core/runner/sandbox/Http.go
index ea28e9f..c207f70 100644
--- a/core/runner/HttpModule.go
+++ b/core/runner/sandbox/Http.go
@@ -1,4 +1,4 @@
-package runner
+package sandbox
import (
"context"
@@ -28,7 +28,7 @@ type HTTPResponse struct {
// Response pool to reduce allocations
var responsePool = sync.Pool{
- New: func() interface{} {
+ New: func() any {
return &HTTPResponse{
Status: 200,
Headers: make(map[string]string, 8), // Pre-allocate with reasonable capacity
@@ -37,36 +37,6 @@ var responsePool = sync.Pool{
},
}
-// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool
-func NewHTTPResponse() *HTTPResponse {
- return responsePool.Get().(*HTTPResponse)
-}
-
-// ReleaseResponse returns the response to the pool after clearing its values
-func ReleaseResponse(resp *HTTPResponse) {
- if resp == nil {
- return
- }
-
- // Clear all values to prevent data leakage
- resp.Status = 200 // Reset to default
-
- // Clear headers
- for k := range resp.Headers {
- delete(resp.Headers, k)
- }
-
- // Clear cookies
- resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0
-
- // Clear body
- resp.Body = nil
-
- responsePool.Put(resp)
-}
-
-// ---------- HTTP CLIENT FUNCTIONALITY ----------
-
// Default HTTP client with sensible timeout
var defaultFastClient fasthttp.Client = fasthttp.Client{
MaxConnsPerHost: 1024,
@@ -96,8 +66,256 @@ var DefaultHTTPClientConfig = HTTPClientConfig{
AllowRemote: true,
}
-// Function name constant to ensure consistency
-const httpRequestFuncName = "__http_request"
+// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool
+func NewHTTPResponse() *HTTPResponse {
+ return responsePool.Get().(*HTTPResponse)
+}
+
+// ReleaseResponse returns the response to the pool after clearing its values
+func ReleaseResponse(resp *HTTPResponse) {
+ if resp == nil {
+ return
+ }
+
+ // Clear all values to prevent data leakage
+ resp.Status = 200 // Reset to default
+
+ // Clear headers
+ for k := range resp.Headers {
+ delete(resp.Headers, k)
+ }
+
+ // Clear cookies
+ resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0
+
+ // Clear body
+ resp.Body = nil
+
+ responsePool.Put(resp)
+}
+
+// HTTPModuleInitFunc returns an initializer function for the HTTP module
+func HTTPModuleInitFunc() func(*luajit.State) error {
+ return func(state *luajit.State) error {
+ // Register the native Go function first
+ if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
+ logger.Error("[HTTP Module] Failed to register __http_request function")
+ logger.ErrorCont("%v", err)
+ return err
+ }
+
+ // Set up default HTTP client configuration
+ setupHTTPClientConfig(state)
+
+ return nil
+ }
+}
+
+// Helper to set up HTTP client config
+func setupHTTPClientConfig(state *luajit.State) {
+ state.NewTable()
+
+ state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
+ state.SetField(-2, "max_timeout")
+
+ state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second))
+ state.SetField(-2, "default_timeout")
+
+ state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize))
+ state.SetField(-2, "max_response_size")
+
+ state.PushBoolean(DefaultHTTPClientConfig.AllowRemote)
+ state.SetField(-2, "allow_remote")
+
+ state.SetGlobal("__http_client_config")
+}
+
+// GetHTTPResponse extracts the HTTP response from Lua state
+func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
+ response := NewHTTPResponse()
+
+ // Get response table
+ state.GetGlobal("__http_responses")
+ if state.IsNil(-1) {
+ state.Pop(1)
+ ReleaseResponse(response) // Return unused response to pool
+ return nil, false
+ }
+
+ // Check for response at thread index
+ state.PushNumber(1)
+ state.GetTable(-2)
+ if state.IsNil(-1) {
+ state.Pop(2)
+ ReleaseResponse(response) // Return unused response to pool
+ return nil, false
+ }
+
+ // Get status
+ state.GetField(-1, "status")
+ if state.IsNumber(-1) {
+ response.Status = int(state.ToNumber(-1))
+ }
+ state.Pop(1)
+
+ // Get headers
+ state.GetField(-1, "headers")
+ if state.IsTable(-1) {
+ // Iterate through headers table
+ state.PushNil() // Start iteration
+ for state.Next(-2) {
+ // Stack has key at -2 and value at -1
+ if state.IsString(-2) && state.IsString(-1) {
+ key := state.ToString(-2)
+ value := state.ToString(-1)
+ response.Headers[key] = value
+ }
+ state.Pop(1) // Pop value, leave key for next iteration
+ }
+ }
+ state.Pop(1)
+
+ // Get cookies
+ state.GetField(-1, "cookies")
+ if state.IsTable(-1) {
+ // Iterate through cookies array
+ length := state.GetTableLength(-1)
+ for i := 1; i <= length; i++ {
+ state.PushNumber(float64(i))
+ state.GetTable(-2)
+
+ if state.IsTable(-1) {
+ cookie := extractCookie(state)
+ if cookie != nil {
+ response.Cookies = append(response.Cookies, cookie)
+ }
+ }
+ state.Pop(1)
+ }
+ }
+ state.Pop(1)
+
+ // Clean up
+ state.Pop(2) // Pop response table and __http_responses
+
+ return response, true
+}
+
+// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
+func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
+ // Set status code
+ ctx.SetStatusCode(httpResp.Status)
+
+ // Set headers
+ for name, value := range httpResp.Headers {
+ ctx.Response.Header.Set(name, value)
+ }
+
+ // Set cookies
+ for _, cookie := range httpResp.Cookies {
+ ctx.Response.Header.SetCookie(cookie)
+ }
+
+ // Process the body based on its type
+ if httpResp.Body == nil {
+ return
+ }
+
+ // Set body based on type
+ switch body := httpResp.Body.(type) {
+ case string:
+ ctx.SetBodyString(body)
+ case []byte:
+ ctx.SetBody(body)
+ case map[string]any, []any, []float64, []string, []int:
+ // Marshal JSON using a buffer from the pool
+ buf := bytebufferpool.Get()
+ defer bytebufferpool.Put(buf)
+
+ if err := json.NewEncoder(buf).Encode(body); err == nil {
+ // Set content type if not already set
+ if len(ctx.Response.Header.ContentType()) == 0 {
+ ctx.Response.Header.SetContentType("application/json")
+ }
+ ctx.SetBody(buf.Bytes())
+ } else {
+ // Fallback
+ ctx.SetBodyString(fmt.Sprintf("%v", body))
+ }
+ default:
+ // Default to string representation
+ ctx.SetBodyString(fmt.Sprintf("%v", body))
+ }
+}
+
+// extractCookie grabs cookies from the Lua state
+func extractCookie(state *luajit.State) *fasthttp.Cookie {
+ cookie := fasthttp.AcquireCookie()
+
+ // Get name
+ state.GetField(-1, "name")
+ if !state.IsString(-1) {
+ state.Pop(1)
+ fasthttp.ReleaseCookie(cookie)
+ return nil // Name is required
+ }
+ cookie.SetKey(state.ToString(-1))
+ state.Pop(1)
+
+ // Get value
+ state.GetField(-1, "value")
+ if state.IsString(-1) {
+ cookie.SetValue(state.ToString(-1))
+ }
+ state.Pop(1)
+
+ // Get path
+ state.GetField(-1, "path")
+ if state.IsString(-1) {
+ cookie.SetPath(state.ToString(-1))
+ } else {
+ cookie.SetPath("/") // Default path
+ }
+ state.Pop(1)
+
+ // Get domain
+ state.GetField(-1, "domain")
+ if state.IsString(-1) {
+ cookie.SetDomain(state.ToString(-1))
+ }
+ state.Pop(1)
+
+ // Get expires
+ state.GetField(-1, "expires")
+ if state.IsNumber(-1) {
+ expiry := int64(state.ToNumber(-1))
+ cookie.SetExpire(time.Unix(expiry, 0))
+ }
+ state.Pop(1)
+
+ // Get max age
+ state.GetField(-1, "max_age")
+ if state.IsNumber(-1) {
+ cookie.SetMaxAge(int(state.ToNumber(-1)))
+ }
+ state.Pop(1)
+
+ // Get secure
+ state.GetField(-1, "secure")
+ if state.IsBoolean(-1) {
+ cookie.SetSecure(state.ToBoolean(-1))
+ }
+ state.Pop(1)
+
+ // Get http only
+ state.GetField(-1, "http_only")
+ if state.IsBoolean(-1) {
+ cookie.SetHTTPOnly(state.ToBoolean(-1))
+ }
+ state.Pop(1)
+
+ return cookie
+}
// httpRequest makes an HTTP request and returns the result to Lua
func httpRequest(state *luajit.State) int {
@@ -360,372 +578,3 @@ func httpRequest(state *luajit.State) int {
return 1
}
-
-// HTTPModuleInitFunc returns an initializer function for the HTTP module
-func HTTPModuleInitFunc() StateInitFunc {
- return func(state *luajit.State) error {
- // CRITICAL: Register the native Go function first
- // This must be done BEFORE any Lua code that references it
- if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil {
- logger.Error("[HTTP Module] Failed to register __http_request function")
- logger.ErrorCont("%v", err)
- return err
- }
-
- // Set up default HTTP client configuration
- setupHTTPClientConfig(state)
-
- // Initialize Lua HTTP module
- if err := state.DoString(LuaHTTPModule); err != nil {
- logger.Error("[HTTP Module] Failed to initialize HTTP module Lua code")
- logger.ErrorCont("%v", err)
- return err
- }
-
- // Verify HTTP client functions are available
- verifyHTTPClient(state)
-
- return nil
- }
-}
-
-// Helper to set up HTTP client config
-func setupHTTPClientConfig(state *luajit.State) {
- state.NewTable()
-
- state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
- state.SetField(-2, "max_timeout")
-
- state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second))
- state.SetField(-2, "default_timeout")
-
- state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize))
- state.SetField(-2, "max_response_size")
-
- state.PushBoolean(DefaultHTTPClientConfig.AllowRemote)
- state.SetField(-2, "allow_remote")
-
- state.SetGlobal("__http_client_config")
-}
-
-// GetHTTPResponse extracts the HTTP response from Lua state
-func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
- response := NewHTTPResponse()
-
- // Get response table
- state.GetGlobal("__http_responses")
- if state.IsNil(-1) {
- state.Pop(1)
- ReleaseResponse(response) // Return unused response to pool
- return nil, false
- }
-
- // Check for response at thread index
- state.PushNumber(1)
- state.GetTable(-2)
- if state.IsNil(-1) {
- state.Pop(2)
- ReleaseResponse(response) // Return unused response to pool
- return nil, false
- }
-
- // Get status
- state.GetField(-1, "status")
- if state.IsNumber(-1) {
- response.Status = int(state.ToNumber(-1))
- }
- state.Pop(1)
-
- // Get headers
- state.GetField(-1, "headers")
- if state.IsTable(-1) {
- // Iterate through headers table
- state.PushNil() // Start iteration
- for state.Next(-2) {
- // Stack has key at -2 and value at -1
- if state.IsString(-2) && state.IsString(-1) {
- key := state.ToString(-2)
- value := state.ToString(-1)
- response.Headers[key] = value
- }
- state.Pop(1) // Pop value, leave key for next iteration
- }
- }
- state.Pop(1)
-
- // Get cookies
- state.GetField(-1, "cookies")
- if state.IsTable(-1) {
- // Iterate through cookies array
- length := state.GetTableLength(-1)
- for i := 1; i <= length; i++ {
- state.PushNumber(float64(i))
- state.GetTable(-2)
-
- if state.IsTable(-1) {
- cookie := extractCookie(state)
- if cookie != nil {
- response.Cookies = append(response.Cookies, cookie)
- }
- }
- state.Pop(1)
- }
- }
- state.Pop(1)
-
- // Clean up
- state.Pop(2) // Pop response table and __http_responses
-
- return response, true
-}
-
-// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
-func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
- // Set status code
- ctx.SetStatusCode(httpResp.Status)
-
- // Set headers
- for name, value := range httpResp.Headers {
- ctx.Response.Header.Set(name, value)
- }
-
- // Set cookies
- for _, cookie := range httpResp.Cookies {
- ctx.Response.Header.SetCookie(cookie)
- }
-
- // Process the body based on its type
- if httpResp.Body == nil {
- return
- }
-
- // Set body based on type
- switch body := httpResp.Body.(type) {
- case string:
- ctx.SetBodyString(body)
- case []byte:
- ctx.SetBody(body)
- case map[string]any, []any, []float64, []string, []int:
- // Marshal JSON using a buffer from the pool
- buf := bytebufferpool.Get()
- defer bytebufferpool.Put(buf)
-
- if err := json.NewEncoder(buf).Encode(body); err == nil {
- // Set content type if not already set
- if len(ctx.Response.Header.ContentType()) == 0 {
- ctx.Response.Header.SetContentType("application/json")
- }
- ctx.SetBody(buf.Bytes())
- } else {
- // Fallback
- ctx.SetBodyString(fmt.Sprintf("%v", body))
- }
- default:
- // Default to string representation
- ctx.SetBodyString(fmt.Sprintf("%v", body))
- }
-}
-
-// WithHTTPClientConfig creates a runner option to configure the HTTP client
-func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption {
- return func(r *Runner) {
- // Store the config to be applied during initialization
- r.AddModule("__http_client_config", map[string]any{
- "max_timeout": float64(config.MaxTimeout / time.Second),
- "default_timeout": float64(config.DefaultTimeout / time.Second),
- "max_response_size": float64(config.MaxResponseSize),
- "allow_remote": config.AllowRemote,
- })
- }
-}
-
-// RestrictHTTPToLocalhost is a convenience function to restrict HTTP client
-// to localhost connections only
-func RestrictHTTPToLocalhost() RunnerOption {
- return WithHTTPClientConfig(HTTPClientConfig{
- MaxTimeout: DefaultHTTPClientConfig.MaxTimeout,
- DefaultTimeout: DefaultHTTPClientConfig.DefaultTimeout,
- MaxResponseSize: DefaultHTTPClientConfig.MaxResponseSize,
- AllowRemote: false,
- })
-}
-
-// Verify that HTTP client is properly set up
-func verifyHTTPClient(state *luajit.State) {
- // Get the client table
- state.GetGlobal("http")
- if !state.IsTable(-1) {
- logger.Warning("[HTTP Module] 'http' is not a table")
- state.Pop(1)
- return
- }
-
- state.GetField(-1, "client")
- if !state.IsTable(-1) {
- logger.Warning("[HTTP Module] 'http.client' is not a table")
- state.Pop(2)
- return
- }
-
- // Check for get function
- state.GetField(-1, "get")
- if !state.IsFunction(-1) {
- logger.Warning("[HTTP Module] 'http.client.get' is not a function")
- }
- state.Pop(1)
-
- // Check for the request function
- state.GetField(-1, "request")
- if !state.IsFunction(-1) {
- logger.Warning("[HTTP Module] 'http.client.request' is not a function")
- }
- state.Pop(3) // Pop request, client, http
-}
-
-const LuaHTTPModule = `
--- Table to store response data
-__http_responses = {}
-
--- HTTP module implementation
-local http = {
- -- Set HTTP status code
- set_status = function(code)
- if type(code) ~= "number" then
- error("http.set_status: status code must be a number", 2)
- end
-
- local resp = __http_responses[1] or {}
- resp.status = code
- __http_responses[1] = resp
- end,
-
- -- Set HTTP header
- set_header = function(name, value)
- if type(name) ~= "string" or type(value) ~= "string" then
- error("http.set_header: name and value must be strings", 2)
- end
-
- local resp = __http_responses[1] or {}
- resp.headers = resp.headers or {}
- resp.headers[name] = value
- __http_responses[1] = resp
- end,
-
- -- Set content type; set_header helper
- set_content_type = function(content_type)
- http.set_header("Content-Type", content_type)
- end,
-
- -- HTTP client submodule
- client = {
- -- Generic request function
- request = function(method, url, body, options)
- if type(method) ~= "string" then
- error("http.client.request: method must be a string", 2)
- end
- if type(url) ~= "string" then
- error("http.client.request: url must be a string", 2)
- end
-
- -- Call native implementation (this is the critical part)
- local result = __http_request(method, url, body, options)
- return result
- end,
-
- -- Simple GET request
- get = function(url, options)
- return http.client.request("GET", url, nil, options)
- end,
-
- -- Simple POST request with automatic content-type
- post = function(url, body, options)
- options = options or {}
- return http.client.request("POST", url, body, options)
- end,
-
- -- Simple PUT request with automatic content-type
- put = function(url, body, options)
- options = options or {}
- return http.client.request("PUT", url, body, options)
- end,
-
- -- Simple DELETE request
- delete = function(url, options)
- return http.client.request("DELETE", url, nil, options)
- end,
-
- -- Simple PATCH request
- patch = function(url, body, options)
- options = options or {}
- return http.client.request("PATCH", url, body, options)
- end,
-
- -- Simple HEAD request
- head = function(url, options)
- options = options or {}
- local old_options = options
- options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
- local response = http.client.request("HEAD", url, nil, options)
- return response
- end,
-
- -- Simple OPTIONS request
- options = function(url, options)
- return http.client.request("OPTIONS", url, nil, options)
- end,
-
- -- Shorthand function to directly get JSON
- get_json = function(url, options)
- options = options or {}
- local response = http.client.get(url, options)
- if response.ok and response.json then
- return response.json
- end
- return nil, response
- end,
-
- -- Utility to build a URL with query parameters
- build_url = function(base_url, params)
- if not params or type(params) ~= "table" then
- return base_url
- end
-
- local query = {}
- for k, v in pairs(params) do
- if type(v) == "table" then
- for _, item in ipairs(v) do
- table.insert(query, k .. "=" .. tostring(item))
- end
- else
- table.insert(query, k .. "=" .. tostring(v))
- end
- end
-
- if #query > 0 then
- if base_url:find("?") then
- return base_url .. "&" .. table.concat(query, "&")
- else
- return base_url .. "?" .. table.concat(query, "&")
- end
- end
-
- return base_url
- end
- }
-}
-
--- Install HTTP module
-_G.http = http
-
--- Clear previous responses when executing scripts
-local old_execute_script = __execute_script
-if old_execute_script then
- __execute_script = function(fn, ctx)
- -- Clear previous response
- __http_responses[1] = nil
-
- -- Execute original function
- return old_execute_script(fn, ctx)
- end
-end
-`
diff --git a/core/runner/sandbox/Modules.go b/core/runner/sandbox/Modules.go
new file mode 100644
index 0000000..c9a9764
--- /dev/null
+++ b/core/runner/sandbox/Modules.go
@@ -0,0 +1,86 @@
+package sandbox
+
+import (
+ "Moonshark/core/utils/logger"
+
+ luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
+)
+
+// ModuleFunc is a function that returns a map of module functions
+type ModuleFunc func() map[string]luajit.GoFunction
+
+// RegisterModule registers a map of functions as a Lua module
+func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
+ // Create a new table for the module
+ state.NewTable()
+
+ // Add each function to the module table
+ for fname, f := range funcs {
+ // Push function name
+ state.PushString(fname)
+
+ // Push function
+ if err := state.PushGoFunction(f); err != nil {
+ state.Pop(1) // Pop table
+ return err
+ }
+
+ // Set table[fname] = f
+ state.SetTable(-3)
+ }
+
+ // Register the module globally
+ state.SetGlobal(name)
+ return nil
+}
+
+// CombineInitFuncs combines multiple state initializer functions into one
+func CombineInitFuncs(funcs ...func(*luajit.State) error) func(*luajit.State) error {
+ return func(state *luajit.State) error {
+ for _, f := range funcs {
+ if f != nil {
+ if err := f(state); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+ }
+}
+
+// ModuleInitFunc creates a state initializer that registers multiple modules
+func ModuleInitFunc(modules map[string]ModuleFunc) func(*luajit.State) error {
+ return func(state *luajit.State) error {
+ for name, moduleFunc := range modules {
+ if err := RegisterModule(state, name, moduleFunc()); err != nil {
+ logger.Error("Failed to register module %s: %v", name, err)
+ return err
+ }
+ }
+ return nil
+ }
+}
+
+// RegisterLuaCode registers a Lua code snippet as a module
+func RegisterLuaCode(state *luajit.State, code string) error {
+ return state.DoString(code)
+}
+
+// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code
+func RegisterLuaCodeInitFunc(code string) func(*luajit.State) error {
+ return func(state *luajit.State) error {
+ return RegisterLuaCode(state, code)
+ }
+}
+
+// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module
+func RegisterLuaModuleInitFunc(name string, code string) func(*luajit.State) error {
+ return func(state *luajit.State) error {
+ // Create name = {} global
+ state.NewTable()
+ state.SetGlobal(name)
+
+ // Then run the module code which will populate it
+ return state.DoString(code)
+ }
+}
diff --git a/core/runner/Sandbox.go b/core/runner/sandbox/Sandbox.go
similarity index 73%
rename from core/runner/Sandbox.go
rename to core/runner/sandbox/Sandbox.go
index 3f01b47..b502df0 100644
--- a/core/runner/Sandbox.go
+++ b/core/runner/sandbox/Sandbox.go
@@ -1,4 +1,4 @@
-package runner
+package sandbox
import (
"fmt"
@@ -6,24 +6,52 @@ import (
"github.com/goccy/go-json"
"github.com/valyala/bytebufferpool"
+ "github.com/valyala/fasthttp"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
+// Global bytecode cache to improve performance
+var (
+ sandboxBytecode []byte
+ bytecodeOnce sync.Once
+)
+
+// precompileSandbox compiles the sandbox.lua code to bytecode once
+func precompileSandbox() {
+ tempState := luajit.New()
+ if tempState == nil {
+ logger.Error("Failed to create temporary Lua state for bytecode compilation")
+ return
+ }
+ defer tempState.Close()
+ defer tempState.Cleanup()
+
+ var err error
+ sandboxBytecode, err = tempState.CompileBytecode(sandboxLua, "sandbox.lua")
+ if err != nil {
+ logger.Error("Failed to precompile sandbox.lua: %v", err)
+ } else {
+ logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(sandboxBytecode))
+ }
+}
+
// Sandbox provides a secure execution environment for Lua scripts
type Sandbox struct {
- modules map[string]any // Custom modules for environment
- debug bool // Enable debug output
- mu sync.RWMutex // Protects modules
+ modules map[string]any // Custom modules for environment
+ debug bool // Enable debug output
+ mu sync.RWMutex // Protects modules
+ initializers *ModuleInitializers // Module initializers
}
// NewSandbox creates a new sandbox environment
func NewSandbox() *Sandbox {
return &Sandbox{
- modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity
- debug: false,
+ modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity
+ debug: false,
+ initializers: DefaultInitializers(),
}
}
@@ -39,7 +67,7 @@ func (s *Sandbox) debugLog(format string, args ...interface{}) {
}
}
-// debugLog logs a message if debug mode is enabled
+// debugLogCont logs a continuation message if debug mode is enabled
func (s *Sandbox) debugLogCont(format string, args ...interface{}) {
if s.debug {
logger.DebugCont(format, args...)
@@ -60,19 +88,27 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
verbose := stateIndex == 0
if verbose {
- s.debugLog("is setting up...")
+ s.debugLog("Setting up sandbox...")
}
- // Register modules in the global environment
+ // Initialize modules with the embedded sandbox code
+ if err := InitializeAll(state, s.initializers); err != nil {
+ if verbose {
+ s.debugLog("Failed to initialize sandbox: %v", err)
+ }
+ return err
+ }
+
+ // Register custom modules in the global environment
s.mu.RLock()
for name, module := range s.modules {
if verbose {
- s.debugLog("is registering module: %s", name)
+ s.debugLog("Registering module: %s", name)
}
if err := state.PushValue(module); err != nil {
s.mu.RUnlock()
if verbose {
- s.debugLog("failed to register module %s: %v", name, err)
+ s.debugLog("Failed to register module %s: %v", name, err)
}
return err
}
@@ -80,60 +116,8 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
}
s.mu.RUnlock()
- // Initialize environment setup
- err := state.DoString(`
- -- Global tables for response handling
- __http_responses = __http_responses or {}
-
- -- Create environment inheriting from _G
- function __create_env(ctx)
- -- Create environment with metatable inheriting from _G
- local env = setmetatable({}, {__index = _G})
-
- -- Add context if provided
- if ctx then
- env.ctx = ctx
- end
-
- -- Add proper require function to this environment
- if __setup_require then
- __setup_require(env)
- end
-
- return env
- end
-
- -- Execute script with clean environment
- function __execute_script(fn, ctx)
- -- Clear previous responses
- __http_responses[1] = nil
-
- -- Create environment
- local env = __create_env(ctx)
-
- -- Set environment for function
- setfenv(fn, env)
-
- -- Execute with protected call
- local ok, result = pcall(fn)
- if not ok then
- error(result, 0)
- end
-
- return result
- end
- `)
-
- if err != nil {
- if verbose {
- s.debugLog("failed to set up...")
- s.debugLogCont("%v", err)
- }
- return err
- }
-
if verbose {
- s.debugLogCont("Complete")
+ s.debugLogCont("Sandbox setup complete")
}
return nil
}
@@ -152,6 +136,14 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
return s.OptimizedExecute(state, bytecode, nil)
}
+// Context represents execution context for a Lua script
+type Context struct {
+ // Values stores any context values (route params, HTTP request info, etc.)
+ Values map[string]any
+ // RequestCtx for HTTP requests
+ RequestCtx *fasthttp.RequestCtx
+}
+
// OptimizedExecute runs bytecode with a fasthttp context if available
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Context) (any, error) {
// Use a buffer from the pool for any string operations
diff --git a/core/runner/UtilModule.go b/core/runner/sandbox/Utils.go
similarity index 85%
rename from core/runner/UtilModule.go
rename to core/runner/sandbox/Utils.go
index 9141af5..75863a9 100644
--- a/core/runner/UtilModule.go
+++ b/core/runner/sandbox/Utils.go
@@ -1,4 +1,4 @@
-package runner
+package sandbox
import (
"crypto/rand"
@@ -9,6 +9,20 @@ import (
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
+// UtilModuleInitFunc returns an initializer for the util module
+func UtilModuleInitFunc() func(*luajit.State) error {
+ return func(state *luajit.State) error {
+ return RegisterModule(state, "util", UtilModuleFunctions())
+ }
+}
+
+// UtilModuleFunctions returns all functions for the util module
+func UtilModuleFunctions() map[string]luajit.GoFunction {
+ return map[string]luajit.GoFunction{
+ "generate_token": GenerateToken,
+ }
+}
+
// GenerateToken creates a cryptographically secure random token
func GenerateToken(s *luajit.State) int {
// Get the length from the Lua arguments (default to 32)
@@ -42,17 +56,3 @@ func GenerateToken(s *luajit.State) int {
s.PushString(token)
return 1 // One return value
}
-
-// UtilModuleFunctions returns all functions for the go module
-func UtilModuleFunctions() map[string]luajit.GoFunction {
- return map[string]luajit.GoFunction{
- "generate_token": GenerateToken,
- }
-}
-
-// UtilModuleInitFunc returns an initializer for the go module
-func UtilModuleInitFunc() StateInitFunc {
- return func(state *luajit.State) error {
- return RegisterModule(state, "util", UtilModuleFunctions())
- }
-}
diff --git a/core/runner/sandbox/lua/sandbox.lua b/core/runner/sandbox/lua/sandbox.lua
new file mode 100644
index 0000000..c755c27
--- /dev/null
+++ b/core/runner/sandbox/lua/sandbox.lua
@@ -0,0 +1,552 @@
+--[[
+Moonshark Lua Sandbox Environment
+
+This file contains all the Lua code needed for the sandbox environment,
+including core modules and utilities. It's designed to be embedded in the
+Go binary at build time.
+]]--
+
+-- Global tables for execution context
+__http_responses = {}
+__module_paths = {}
+__module_bytecode = {}
+__ready_modules = {}
+__session_data = {}
+__session_id = nil
+__session_modified = false
+__env_system = {
+ base_env = {}
+}
+
+-- ======================================================================
+-- CORE SANDBOX FUNCTIONALITY
+-- ======================================================================
+
+-- Create environment inheriting from _G
+function __create_env(ctx)
+ -- Create environment with metatable inheriting from _G
+ local env = setmetatable({}, {__index = _G})
+
+ -- Add context if provided
+ if ctx then
+ env.ctx = ctx
+ end
+
+ -- Add proper require function to this environment
+ if __setup_require then
+ __setup_require(env)
+ end
+
+ return env
+end
+
+-- Execute script with clean environment
+function __execute_script(fn, ctx)
+ -- Clear previous responses
+ __http_responses[1] = nil
+
+ -- Reset session modification flag
+ __session_modified = false
+
+ -- Create environment
+ local env = __create_env(ctx)
+
+ -- Set environment for function
+ setfenv(fn, env)
+
+ -- Execute with protected call
+ local ok, result = pcall(fn)
+ if not ok then
+ error(result, 0)
+ end
+
+ return result
+end
+
+-- ======================================================================
+-- MODULE LOADING SYSTEM
+-- ======================================================================
+
+-- Setup environment-aware require function
+function __setup_require(env)
+ -- Create require function specific to this environment
+ env.require = function(modname)
+ -- Check if already loaded
+ if package.loaded[modname] then
+ return package.loaded[modname]
+ end
+
+ -- Check preloaded modules
+ if __ready_modules[modname] then
+ local loader = package.preload[modname]
+ if loader then
+ -- Set environment for loader
+ setfenv(loader, env)
+
+ -- Execute and store result
+ local result = loader()
+ if result == nil then
+ result = true
+ end
+
+ package.loaded[modname] = result
+ return result
+ end
+ end
+
+ -- Direct file load as fallback
+ if __module_paths[modname] then
+ local path = __module_paths[modname]
+ local chunk, err = loadfile(path)
+ if chunk then
+ setfenv(chunk, env)
+ local result = chunk()
+ if result == nil then
+ result = true
+ end
+ package.loaded[modname] = result
+ return result
+ end
+ end
+
+ -- Full path search as last resort
+ local errors = {}
+ for path in package.path:gmatch("[^;]+") do
+ local file_path = path:gsub("?", modname:gsub("%.", "/"))
+ local chunk, err = loadfile(file_path)
+ if chunk then
+ setfenv(chunk, env)
+ local result = chunk()
+ if result == nil then
+ result = true
+ end
+ package.loaded[modname] = result
+ return result
+ end
+ table.insert(errors, "\tno file '" .. file_path .. "'")
+ end
+
+ error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2)
+ end
+
+ return env
+end
+
+-- ======================================================================
+-- HTTP MODULE
+-- ======================================================================
+
+-- HTTP module implementation
+local http = {
+ -- Set HTTP status code
+ set_status = function(code)
+ if type(code) ~= "number" then
+ error("http.set_status: status code must be a number", 2)
+ end
+
+ local resp = __http_responses[1] or {}
+ resp.status = code
+ __http_responses[1] = resp
+ end,
+
+ -- Set HTTP header
+ set_header = function(name, value)
+ if type(name) ~= "string" or type(value) ~= "string" then
+ error("http.set_header: name and value must be strings", 2)
+ end
+
+ local resp = __http_responses[1] or {}
+ resp.headers = resp.headers or {}
+ resp.headers[name] = value
+ __http_responses[1] = resp
+ end,
+
+ -- Set content type; set_header helper
+ set_content_type = function(content_type)
+ http.set_header("Content-Type", content_type)
+ end,
+
+ -- HTTP client submodule
+ client = {
+ -- Generic request function
+ request = function(method, url, body, options)
+ if type(method) ~= "string" then
+ error("http.client.request: method must be a string", 2)
+ end
+ if type(url) ~= "string" then
+ error("http.client.request: url must be a string", 2)
+ end
+
+ -- Call native implementation
+ local result = __http_request(method, url, body, options)
+ return result
+ end,
+
+ -- Simple GET request
+ get = function(url, options)
+ return http.client.request("GET", url, nil, options)
+ end,
+
+ -- Simple POST request with automatic content-type
+ post = function(url, body, options)
+ options = options or {}
+ return http.client.request("POST", url, body, options)
+ end,
+
+ -- Simple PUT request with automatic content-type
+ put = function(url, body, options)
+ options = options or {}
+ return http.client.request("PUT", url, body, options)
+ end,
+
+ -- Simple DELETE request
+ delete = function(url, options)
+ return http.client.request("DELETE", url, nil, options)
+ end,
+
+ -- Simple PATCH request
+ patch = function(url, body, options)
+ options = options or {}
+ return http.client.request("PATCH", url, body, options)
+ end,
+
+ -- Simple HEAD request
+ head = function(url, options)
+ options = options or {}
+ local old_options = options
+ options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
+ local response = http.client.request("HEAD", url, nil, options)
+ return response
+ end,
+
+ -- Simple OPTIONS request
+ options = function(url, options)
+ return http.client.request("OPTIONS", url, nil, options)
+ end,
+
+ -- Shorthand function to directly get JSON
+ get_json = function(url, options)
+ options = options or {}
+ local response = http.client.get(url, options)
+ if response.ok and response.json then
+ return response.json
+ end
+ return nil, response
+ end,
+
+ -- Utility to build a URL with query parameters
+ build_url = function(base_url, params)
+ if not params or type(params) ~= "table" then
+ return base_url
+ end
+
+ local query = {}
+ for k, v in pairs(params) do
+ if type(v) == "table" then
+ for _, item in ipairs(v) do
+ table.insert(query, k .. "=" .. tostring(item))
+ end
+ else
+ table.insert(query, k .. "=" .. tostring(v))
+ end
+ end
+
+ if #query > 0 then
+ if base_url:find("?") then
+ return base_url .. "&" .. table.concat(query, "&")
+ else
+ return base_url .. "?" .. table.concat(query, "&")
+ end
+ end
+
+ return base_url
+ end
+ }
+}
+
+-- ======================================================================
+-- COOKIE MODULE
+-- ======================================================================
+
+-- Cookie module implementation
+local cookie = {
+ -- Set a cookie
+ set = function(name, value, options, ...)
+ if type(name) ~= "string" then
+ error("cookie.set: name must be a string", 2)
+ end
+
+ -- Get or create response
+ local resp = __http_responses[1] or {}
+ resp.cookies = resp.cookies or {}
+ __http_responses[1] = resp
+
+ -- Handle options as table or legacy params
+ local opts = {}
+ if type(options) == "table" then
+ opts = options
+ elseif options ~= nil then
+ -- Legacy support: options is actually 'expires'
+ opts.expires = options
+ -- Check for other legacy params (4th-7th args)
+ local args = {...}
+ if args[1] then opts.path = args[1] end
+ if args[2] then opts.domain = args[2] end
+ if args[3] then opts.secure = args[3] end
+ if args[4] ~= nil then opts.http_only = args[4] end
+ end
+
+ -- Create cookie table
+ local cookie = {
+ name = name,
+ value = value or "",
+ path = opts.path or "/",
+ domain = opts.domain
+ }
+
+ -- Handle expiry
+ if opts.expires then
+ if type(opts.expires) == "number" then
+ if opts.expires > 0 then
+ cookie.max_age = opts.expires
+ local now = os.time()
+ cookie.expires = now + opts.expires
+ elseif opts.expires < 0 then
+ cookie.expires = 1
+ cookie.max_age = 0
+ else
+ -- opts.expires == 0: Session cookie
+ -- Do nothing (omitting both expires and max-age creates a session cookie)
+ end
+ end
+ end
+
+ -- Security flags
+ cookie.secure = (opts.secure ~= false)
+ cookie.http_only = (opts.http_only ~= false)
+
+ -- Store in cookies table
+ local n = #resp.cookies + 1
+ resp.cookies[n] = cookie
+
+ return true
+ end,
+
+ -- Get a cookie value
+ get = function(name)
+ if type(name) ~= "string" then
+ error("cookie.get: name must be a string", 2)
+ end
+
+ -- Access values directly from current environment
+ local env = getfenv(2)
+
+ -- Check if context exists and has cookies
+ if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
+ return tostring(env.ctx.cookies[name])
+ end
+
+ return nil
+ end,
+
+ -- Remove a cookie
+ remove = function(name, path, domain)
+ if type(name) ~= "string" then
+ error("cookie.remove: name must be a string", 2)
+ end
+
+ -- Create an expired cookie
+ return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
+ end
+}
+
+-- ======================================================================
+-- SESSION MODULE
+-- ======================================================================
+
+-- Session module implementation
+local session = {
+ -- Get a session value
+ get = function(key)
+ if type(key) ~= "string" then
+ error("session.get: key must be a string", 2)
+ end
+
+ if __session_data and __session_data[key] then
+ return __session_data[key]
+ end
+
+ return nil
+ end,
+
+ -- Set a session value
+ set = function(key, value)
+ if type(key) ~= "string" then
+ error("session.set: key must be a string", 2)
+ end
+
+ -- Ensure session data table exists
+ __session_data = __session_data or {}
+
+ -- Store value
+ __session_data[key] = value
+
+ -- Mark session as modified
+ __session_modified = true
+
+ return true
+ end,
+
+ -- Delete a session value
+ delete = function(key)
+ if type(key) ~= "string" then
+ error("session.delete: key must be a string", 2)
+ end
+
+ if __session_data then
+ __session_data[key] = nil
+ __session_modified = true
+ end
+
+ return true
+ end,
+
+ -- Clear all session data
+ clear = function()
+ __session_data = {}
+ __session_modified = true
+ return true
+ end,
+
+ -- Get the session ID
+ get_id = function()
+ return __session_id or nil
+ end,
+
+ -- Get all session data
+ get_all = function()
+ local result = {}
+ for k, v in pairs(__session_data or {}) do
+ result[k] = v
+ end
+ return result
+ end,
+
+ -- Check if session has a key
+ has = function(key)
+ if type(key) ~= "string" then
+ error("session.has: key must be a string", 2)
+ end
+
+ return __session_data and __session_data[key] ~= nil
+ end
+}
+
+-- ======================================================================
+-- CSRF MODULE
+-- ======================================================================
+
+-- CSRF protection module
+local csrf = {
+ -- Session key where the token is stored
+ TOKEN_KEY = "_csrf_token",
+
+ -- Default form field name
+ DEFAULT_FIELD = "csrf",
+
+ -- Generate a new CSRF token and store it in the session
+ generate = function(length)
+ -- Default length is 32 characters
+ length = length or 32
+
+ if length < 16 then
+ -- Enforce minimum security
+ length = 16
+ end
+
+ -- Check if we have a session module
+ if not session then
+ error("CSRF protection requires the session module", 2)
+ end
+
+ local token = util.generate_token(length)
+ session.set(csrf.TOKEN_KEY, token)
+ return token
+ end,
+
+ -- Get the current token or generate a new one
+ token = function()
+ -- Get from session if exists
+ local token = session.get(csrf.TOKEN_KEY)
+
+ -- Generate if needed
+ if not token then
+ token = csrf.generate()
+ end
+
+ return token
+ end,
+
+ -- Generate a hidden form field with the CSRF token
+ field = function(field_name)
+ field_name = field_name or csrf.DEFAULT_FIELD
+ local token = csrf.token()
+ return string.format('', field_name, token)
+ end,
+
+ -- Verify a given token against the session token
+ verify = function(token, field_name)
+ field_name = field_name or csrf.DEFAULT_FIELD
+
+ local env = getfenv(2)
+
+ local form = nil
+ if env.ctx and env.ctx.form then
+ form = env.ctx.form
+ else
+ return false
+ end
+
+ token = token or form[field_name]
+ if not token then
+ return false
+ end
+
+ local session_token = session.get(csrf.TOKEN_KEY)
+ if not session_token then
+ return false
+ end
+
+ -- Constant-time comparison to prevent timing attacks
+ -- This is safe since Lua strings are immutable
+ if #token ~= #session_token then
+ return false
+ end
+
+ local result = true
+ for i = 1, #token do
+ if token:sub(i, i) ~= session_token:sub(i, i) then
+ result = false
+ -- Don't break early - continue to prevent timing attacks
+ end
+ end
+
+ return result
+ end
+}
+
+-- ======================================================================
+-- REGISTER MODULES GLOBALLY
+-- ======================================================================
+
+-- Install modules in global scope
+_G.http = http
+_G.cookie = cookie
+_G.session = session
+_G.csrf = csrf
+
+-- Register modules in sandbox base environment
+__env_system.base_env.http = http
+__env_system.base_env.cookie = cookie
+__env_system.base_env.session = session
+__env_system.base_env.csrf = csrf
\ No newline at end of file
diff --git a/core/sessions/Manager.go b/core/sessions/Manager.go
index 8be2a11..7513b91 100644
--- a/core/sessions/Manager.go
+++ b/core/sessions/Manager.go
@@ -87,7 +87,6 @@ func (sm *SessionManager) GetSession(id string) *Session {
func (sm *SessionManager) CreateSession() *Session {
id := generateSessionID()
- // Create new session
session := NewSession(id)
data, _ := json.Marshal(session)
sm.cache.Set([]byte(id), data)
@@ -95,6 +94,12 @@ func (sm *SessionManager) CreateSession() *Session {
return session
}
+// SaveSession persists a session back to the cache
+func (sm *SessionManager) SaveSession(session *Session) {
+ data, _ := json.Marshal(session)
+ sm.cache.Set([]byte(session.ID), data)
+}
+
// DestroySession removes a session
func (sm *SessionManager) DestroySession(id string) {
sm.cache.Del([]byte(id))
diff --git a/core/sessions/Session.go b/core/sessions/Session.go
index fa18609..ec75f7f 100644
--- a/core/sessions/Session.go
+++ b/core/sessions/Session.go
@@ -18,13 +18,13 @@ var (
// Session stores data for a single user session
type Session struct {
- ID string
- Data map[string]any
- CreatedAt time.Time
- UpdatedAt time.Time
- mu sync.RWMutex // Protect concurrent access to Data
- maxValueSize int // Maximum size of individual values in bytes
- totalDataSize int // Track total size of all data
+ ID string `json:"id"`
+ Data map[string]any `json:"data"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+ mu sync.RWMutex `json:"-"`
+ maxValueSize int `json:"max_value_size"`
+ totalDataSize int `json:"total_data_size"`
}
// NewSession creates a new session with the given ID