reorg 2
This commit is contained in:
parent
6f020932c4
commit
c0b493b6bc
|
@ -185,7 +185,7 @@ func (s *Moonshark) initRunner() error {
|
||||||
runner.WithPoolSize(s.Config.Runner.PoolSize),
|
runner.WithPoolSize(s.Config.Runner.PoolSize),
|
||||||
runner.WithLibDirs(s.Config.Dirs.Libs...),
|
runner.WithLibDirs(s.Config.Dirs.Libs...),
|
||||||
runner.WithSessionManager(sessionManager),
|
runner.WithSessionManager(sessionManager),
|
||||||
runner.WithCSRFProtection(),
|
http.WithCSRFProtection(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add debug option conditionally
|
// Add debug option conditionally
|
||||||
|
|
|
@ -1,12 +1,119 @@
|
||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"Moonshark/core/runner"
|
||||||
"Moonshark/core/utils"
|
"Moonshark/core/utils"
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
"crypto/subtle"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ValidateCSRFToken checks if the CSRF token is valid for a request
|
||||||
|
func ValidateCSRFToken(state *luajit.State, ctx *runner.Context) bool {
|
||||||
|
// Only validate for form submissions
|
||||||
|
method, ok := ctx.Get("method").(string)
|
||||||
|
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get form data
|
||||||
|
formData, ok := ctx.Get("form").(map[string]any)
|
||||||
|
if !ok || formData == nil {
|
||||||
|
logger.Warning("CSRF validation failed: no form data")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get token from form
|
||||||
|
formToken, ok := formData["csrf"].(string)
|
||||||
|
if !ok || formToken == "" {
|
||||||
|
logger.Warning("CSRF validation failed: no token in form")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get session token
|
||||||
|
state.GetGlobal("session")
|
||||||
|
if state.IsNil(-1) {
|
||||||
|
state.Pop(1)
|
||||||
|
logger.Warning("CSRF validation failed: session module not available")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
state.GetField(-1, "get")
|
||||||
|
if !state.IsFunction(-1) {
|
||||||
|
state.Pop(2)
|
||||||
|
logger.Warning("CSRF validation failed: session.get not available")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
state.PushCopy(-1) // Duplicate function
|
||||||
|
state.PushString("_csrf_token")
|
||||||
|
|
||||||
|
if err := state.Call(1, 1); err != nil {
|
||||||
|
state.Pop(3) // Pop error, function and session table
|
||||||
|
logger.Warning("CSRF validation failed: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.IsNil(-1) {
|
||||||
|
state.Pop(3) // Pop nil, function and session table
|
||||||
|
logger.Warning("CSRF validation failed: no token in session")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionToken := state.ToString(-1)
|
||||||
|
state.Pop(3) // Pop token, function and session table
|
||||||
|
|
||||||
|
// Constant-time comparison to prevent timing attacks
|
||||||
|
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCSRFProtection creates a runner option to add CSRF protection
|
||||||
|
func WithCSRFProtection() runner.RunnerOption {
|
||||||
|
return func(r *runner.Runner) {
|
||||||
|
r.AddInitHook(func(state *luajit.State, ctx *runner.Context) error {
|
||||||
|
// Get request method
|
||||||
|
method, ok := ctx.Get("method").(string)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only validate for form submissions
|
||||||
|
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for form data
|
||||||
|
form, ok := ctx.Get("form").(map[string]any)
|
||||||
|
if !ok || form == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate CSRF token
|
||||||
|
if !ValidateCSRFToken(state, ctx) {
|
||||||
|
return ErrCSRFValidationFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error for CSRF validation failure
|
||||||
|
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
|
||||||
|
|
||||||
|
// CSRFError represents a CSRF validation error
|
||||||
|
type CSRFError struct {
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface
|
||||||
|
func (e *CSRFError) Error() string {
|
||||||
|
return e.message
|
||||||
|
}
|
||||||
|
|
||||||
// HandleCSRFError handles a CSRF validation error
|
// HandleCSRFError handles a CSRF validation error
|
||||||
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
||||||
method := string(ctx.Method())
|
method := string(ctx.Method())
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"Moonshark/core/metadata"
|
"Moonshark/core/metadata"
|
||||||
"Moonshark/core/routers"
|
"Moonshark/core/routers"
|
||||||
"Moonshark/core/runner"
|
"Moonshark/core/runner"
|
||||||
|
"Moonshark/core/runner/sandbox"
|
||||||
"Moonshark/core/utils"
|
"Moonshark/core/utils"
|
||||||
"Moonshark/core/utils/config"
|
"Moonshark/core/utils/config"
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
@ -226,7 +227,7 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||||
|
|
||||||
// Special handling for CSRF error
|
// Special handling for CSRF error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if csrfErr, ok := err.(*runner.CSRFError); ok {
|
if csrfErr, ok := err.(*CSRFError); ok {
|
||||||
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
|
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
|
||||||
HandleCSRFError(ctx, s.errorConfig)
|
HandleCSRFError(ctx, s.errorConfig)
|
||||||
return
|
return
|
||||||
|
@ -258,8 +259,8 @@ func writeResponse(ctx *fasthttp.RequestCtx, result any) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for HTTPResponse type
|
// Check for HTTPResponse type
|
||||||
if httpResp, ok := result.(*runner.HTTPResponse); ok {
|
if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
|
||||||
defer runner.ReleaseResponse(httpResp)
|
defer sandbox.ReleaseResponse(httpResp)
|
||||||
|
|
||||||
// Set response headers
|
// Set response headers
|
||||||
for name, value := range httpResp.Headers {
|
for name, value := range httpResp.Headers {
|
||||||
|
|
|
@ -3,6 +3,8 @@ package runner
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"maps"
|
||||||
|
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
@ -24,7 +26,7 @@ type Context struct {
|
||||||
|
|
||||||
// Context pool to reduce allocations
|
// Context pool to reduce allocations
|
||||||
var contextPool = sync.Pool{
|
var contextPool = sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() any {
|
||||||
return &Context{
|
return &Context{
|
||||||
Values: make(map[string]any, 16), // Pre-allocate with reasonable capacity
|
Values: make(map[string]any, 16), // Pre-allocate with reasonable capacity
|
||||||
}
|
}
|
||||||
|
@ -115,9 +117,7 @@ func (c *Context) All() map[string]any {
|
||||||
defer c.mu.RUnlock()
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
result := make(map[string]any, len(c.Values))
|
result := make(map[string]any, len(c.Values))
|
||||||
for k, v := range c.Values {
|
maps.Copy(result, c.Values)
|
||||||
result[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,187 +0,0 @@
|
||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// extractCookie grabs cookies from the Lua state
|
|
||||||
func extractCookie(state *luajit.State) *fasthttp.Cookie {
|
|
||||||
cookie := fasthttp.AcquireCookie()
|
|
||||||
|
|
||||||
// Get name
|
|
||||||
state.GetField(-1, "name")
|
|
||||||
if !state.IsString(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
fasthttp.ReleaseCookie(cookie)
|
|
||||||
return nil // Name is required
|
|
||||||
}
|
|
||||||
cookie.SetKey(state.ToString(-1))
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get value
|
|
||||||
state.GetField(-1, "value")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetValue(state.ToString(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get path
|
|
||||||
state.GetField(-1, "path")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetPath(state.ToString(-1))
|
|
||||||
} else {
|
|
||||||
cookie.SetPath("/") // Default path
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get domain
|
|
||||||
state.GetField(-1, "domain")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetDomain(state.ToString(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get expires
|
|
||||||
state.GetField(-1, "expires")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
expiry := int64(state.ToNumber(-1))
|
|
||||||
cookie.SetExpire(time.Unix(expiry, 0))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get max age
|
|
||||||
state.GetField(-1, "max_age")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get secure
|
|
||||||
state.GetField(-1, "secure")
|
|
||||||
if state.IsBoolean(-1) {
|
|
||||||
cookie.SetSecure(state.ToBoolean(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get http only
|
|
||||||
state.GetField(-1, "http_only")
|
|
||||||
if state.IsBoolean(-1) {
|
|
||||||
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
return cookie
|
|
||||||
}
|
|
||||||
|
|
||||||
// LuaCookieModule provides cookie functionality to Lua scripts
|
|
||||||
const LuaCookieModule = `
|
|
||||||
-- Cookie module implementation
|
|
||||||
local cookie = {
|
|
||||||
-- Set a cookie
|
|
||||||
set = function(name, value, options, ...)
|
|
||||||
if type(name) ~= "string" then
|
|
||||||
error("cookie.set: name must be a string", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Get or create response
|
|
||||||
local resp = __http_responses[1] or {}
|
|
||||||
resp.cookies = resp.cookies or {}
|
|
||||||
__http_responses[1] = resp
|
|
||||||
|
|
||||||
-- Handle options as table or legacy params
|
|
||||||
local opts = {}
|
|
||||||
if type(options) == "table" then
|
|
||||||
opts = options
|
|
||||||
elseif options ~= nil then
|
|
||||||
-- Legacy support: options is actually 'expires'
|
|
||||||
opts.expires = options
|
|
||||||
-- Check for other legacy params (4th-7th args)
|
|
||||||
local args = {...}
|
|
||||||
if args[1] then opts.path = args[1] end
|
|
||||||
if args[2] then opts.domain = args[2] end
|
|
||||||
if args[3] then opts.secure = args[3] end
|
|
||||||
if args[4] ~= nil then opts.http_only = args[4] end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Create cookie table
|
|
||||||
local cookie = {
|
|
||||||
name = name,
|
|
||||||
value = value or "",
|
|
||||||
path = opts.path or "/",
|
|
||||||
domain = opts.domain
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Handle expiry
|
|
||||||
if opts.expires then
|
|
||||||
if type(opts.expires) == "number" then
|
|
||||||
if opts.expires > 0 then
|
|
||||||
cookie.max_age = opts.expires
|
|
||||||
local now = os.time()
|
|
||||||
cookie.expires = now + opts.expires
|
|
||||||
elseif opts.expires < 0 then
|
|
||||||
cookie.expires = 1
|
|
||||||
cookie.max_age = 0
|
|
||||||
else
|
|
||||||
-- opts.expires == 0: Session cookie
|
|
||||||
-- Do nothing (omitting both expires and max-age creates a session cookie)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Security flags
|
|
||||||
cookie.secure = (opts.secure ~= false)
|
|
||||||
cookie.http_only = (opts.http_only ~= false)
|
|
||||||
|
|
||||||
-- Store in cookies table
|
|
||||||
local n = #resp.cookies + 1
|
|
||||||
resp.cookies[n] = cookie
|
|
||||||
|
|
||||||
return true
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Get a cookie value
|
|
||||||
get = function(name)
|
|
||||||
if type(name) ~= "string" then
|
|
||||||
error("cookie.get: name must be a string", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Access values directly from current environment
|
|
||||||
local env = getfenv(2)
|
|
||||||
|
|
||||||
-- Check if context exists and has cookies
|
|
||||||
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
|
|
||||||
return tostring(env.ctx.cookies[name])
|
|
||||||
end
|
|
||||||
|
|
||||||
return nil
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Remove a cookie
|
|
||||||
remove = function(name, path, domain)
|
|
||||||
if type(name) ~= "string" then
|
|
||||||
error("cookie.remove: name must be a string", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Create an expired cookie
|
|
||||||
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
|
|
||||||
end
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Install cookie module
|
|
||||||
_G.cookie = cookie
|
|
||||||
|
|
||||||
-- Make sure the cookie module is accessible in sandbox
|
|
||||||
if __env_system and __env_system.base_env then
|
|
||||||
__env_system.base_env.cookie = cookie
|
|
||||||
end
|
|
||||||
`
|
|
||||||
|
|
||||||
// CookieModuleInitFunc returns an initializer for the cookie module
|
|
||||||
func CookieModuleInitFunc() StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
return state.DoString(LuaCookieModule)
|
|
||||||
}
|
|
||||||
}
|
|
77
core/runner/Cookies.go
Normal file
77
core/runner/Cookies.go
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// extractCookie grabs cookies from the Lua state
|
||||||
|
func extractCookie(state *luajit.State) *fasthttp.Cookie {
|
||||||
|
cookie := fasthttp.AcquireCookie()
|
||||||
|
|
||||||
|
// Get name
|
||||||
|
state.GetField(-1, "name")
|
||||||
|
if !state.IsString(-1) {
|
||||||
|
state.Pop(1)
|
||||||
|
fasthttp.ReleaseCookie(cookie)
|
||||||
|
return nil // Name is required
|
||||||
|
}
|
||||||
|
cookie.SetKey(state.ToString(-1))
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get value
|
||||||
|
state.GetField(-1, "value")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
cookie.SetValue(state.ToString(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get path
|
||||||
|
state.GetField(-1, "path")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
cookie.SetPath(state.ToString(-1))
|
||||||
|
} else {
|
||||||
|
cookie.SetPath("/") // Default path
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get domain
|
||||||
|
state.GetField(-1, "domain")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
cookie.SetDomain(state.ToString(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get expires
|
||||||
|
state.GetField(-1, "expires")
|
||||||
|
if state.IsNumber(-1) {
|
||||||
|
expiry := int64(state.ToNumber(-1))
|
||||||
|
cookie.SetExpire(time.Unix(expiry, 0))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get max age
|
||||||
|
state.GetField(-1, "max_age")
|
||||||
|
if state.IsNumber(-1) {
|
||||||
|
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get secure
|
||||||
|
state.GetField(-1, "secure")
|
||||||
|
if state.IsBoolean(-1) {
|
||||||
|
cookie.SetSecure(state.ToBoolean(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get http only
|
||||||
|
state.GetField(-1, "http_only")
|
||||||
|
if state.IsBoolean(-1) {
|
||||||
|
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
return cookie
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package runner
|
package runner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"Moonshark/core/runner/sandbox"
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -265,17 +266,21 @@ func init() {
|
||||||
GlobalRegistry.EnableDebug() // Enable debugging by default
|
GlobalRegistry.EnableDebug() // Enable debugging by default
|
||||||
logger.Debug("[ModuleRegistry] Registering core modules...")
|
logger.Debug("[ModuleRegistry] Registering core modules...")
|
||||||
|
|
||||||
GlobalRegistry.Register("util", UtilModuleInitFunc())
|
// Register core modules - these now point to the sandbox implementations
|
||||||
GlobalRegistry.Register("http", HTTPModuleInitFunc())
|
GlobalRegistry.Register("util", func(state *luajit.State) error {
|
||||||
GlobalRegistry.RegisterWithDependencies("cookie", CookieModuleInitFunc(), []string{"http"})
|
return sandbox.UtilModuleInitFunc()(state)
|
||||||
GlobalRegistry.RegisterWithDependencies("csrf", CSRFModuleInitFunc(), []string{"util"})
|
})
|
||||||
|
|
||||||
|
GlobalRegistry.Register("http", func(state *luajit.State) error {
|
||||||
|
return sandbox.HTTPModuleInitFunc()(state)
|
||||||
|
})
|
||||||
|
|
||||||
// Set explicit initialization order
|
// Set explicit initialization order
|
||||||
GlobalRegistry.SetInitOrder([]string{
|
GlobalRegistry.SetInitOrder([]string{
|
||||||
"util", // First: core utilities
|
"util", // First: core utilities
|
||||||
"http", // Second: HTTP functionality
|
"http", // Second: HTTP functionality
|
||||||
"cookie", // Third: Cookie functionality (uses HTTP)
|
"session", // Third: Session functionality
|
||||||
"csrf", // Fourth: CSRF protection (uses go and possibly session)
|
"csrf", // Fourth: CSRF protection
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.DebugCont("Core modules registered successfully")
|
logger.DebugCont("Core modules registered successfully")
|
||||||
|
|
|
@ -1,219 +0,0 @@
|
||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/subtle"
|
|
||||||
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LuaCSRFModule provides CSRF protection functionality to Lua scripts
|
|
||||||
const LuaCSRFModule = `
|
|
||||||
-- CSRF protection module
|
|
||||||
local csrf = {
|
|
||||||
-- Session key where the token is stored
|
|
||||||
TOKEN_KEY = "_csrf_token",
|
|
||||||
|
|
||||||
-- Default form field name
|
|
||||||
DEFAULT_FIELD = "csrf",
|
|
||||||
|
|
||||||
-- Generate a new CSRF token and store it in the session
|
|
||||||
generate = function(length)
|
|
||||||
-- Default length is 32 characters
|
|
||||||
length = length or 32
|
|
||||||
|
|
||||||
if length < 16 then
|
|
||||||
-- Enforce minimum security
|
|
||||||
length = 16
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Check if we have a session module
|
|
||||||
if not session then
|
|
||||||
error("CSRF protection requires the session module", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
local token = util.generate_token(length)
|
|
||||||
session.set(csrf.TOKEN_KEY, token)
|
|
||||||
return token
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Get the current token or generate a new one
|
|
||||||
token = function()
|
|
||||||
-- Get from session if exists
|
|
||||||
local token = session.get(csrf.TOKEN_KEY)
|
|
||||||
|
|
||||||
-- Generate if needed
|
|
||||||
if not token then
|
|
||||||
token = csrf.generate()
|
|
||||||
end
|
|
||||||
|
|
||||||
return token
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Generate a hidden form field with the CSRF token
|
|
||||||
field = function(field_name)
|
|
||||||
field_name = field_name or csrf.DEFAULT_FIELD
|
|
||||||
local token = csrf.token()
|
|
||||||
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Verify a given token against the session token
|
|
||||||
verify = function(token, field_name)
|
|
||||||
field_name = field_name or csrf.DEFAULT_FIELD
|
|
||||||
|
|
||||||
local env = getfenv(2)
|
|
||||||
|
|
||||||
local form = nil
|
|
||||||
if env.ctx and env.ctx.form then
|
|
||||||
form = env.ctx.form
|
|
||||||
else
|
|
||||||
return false
|
|
||||||
end
|
|
||||||
|
|
||||||
token = token or form[field_name]
|
|
||||||
if not token then
|
|
||||||
return false
|
|
||||||
end
|
|
||||||
|
|
||||||
local session_token = session.get(csrf.TOKEN_KEY)
|
|
||||||
if not session_token then
|
|
||||||
return false
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Constant-time comparison to prevent timing attacks
|
|
||||||
-- This is safe since Lua strings are immutable
|
|
||||||
if #token ~= #session_token then
|
|
||||||
return false
|
|
||||||
end
|
|
||||||
|
|
||||||
local result = true
|
|
||||||
for i = 1, #token do
|
|
||||||
if token:sub(i, i) ~= session_token:sub(i, i) then
|
|
||||||
result = false
|
|
||||||
-- Don't break early - continue to prevent timing attacks
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Install CSRF module
|
|
||||||
_G.csrf = csrf
|
|
||||||
|
|
||||||
-- Make sure the CSRF module is accessible in sandbox
|
|
||||||
if __env_system and __env_system.base_env then
|
|
||||||
__env_system.base_env.csrf = csrf
|
|
||||||
end
|
|
||||||
`
|
|
||||||
|
|
||||||
// CSRFModuleInitFunc returns an initializer for the CSRF module
|
|
||||||
func CSRFModuleInitFunc() StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
return state.DoString(LuaCSRFModule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateCSRFToken checks if the CSRF token is valid for a request
|
|
||||||
func ValidateCSRFToken(state *luajit.State, ctx *Context) bool {
|
|
||||||
// Only validate for form submissions
|
|
||||||
method, ok := ctx.Get("method").(string)
|
|
||||||
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get form data
|
|
||||||
formData, ok := ctx.Get("form").(map[string]any)
|
|
||||||
if !ok || formData == nil {
|
|
||||||
logger.Warning("CSRF validation failed: no form data")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get token from form
|
|
||||||
formToken, ok := formData["csrf"].(string)
|
|
||||||
if !ok || formToken == "" {
|
|
||||||
logger.Warning("CSRF validation failed: no token in form")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get session token
|
|
||||||
state.GetGlobal("session")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
logger.Warning("CSRF validation failed: session module not available")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "get")
|
|
||||||
if !state.IsFunction(-1) {
|
|
||||||
state.Pop(2)
|
|
||||||
logger.Warning("CSRF validation failed: session.get not available")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
state.PushCopy(-1) // Duplicate function
|
|
||||||
state.PushString("_csrf_token")
|
|
||||||
|
|
||||||
if err := state.Call(1, 1); err != nil {
|
|
||||||
state.Pop(3) // Pop error, function and session table
|
|
||||||
logger.Warning("CSRF validation failed: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(3) // Pop nil, function and session table
|
|
||||||
logger.Warning("CSRF validation failed: no token in session")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionToken := state.ToString(-1)
|
|
||||||
state.Pop(3) // Pop token, function and session table
|
|
||||||
|
|
||||||
// Constant-time comparison to prevent timing attacks
|
|
||||||
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCSRFProtection creates a runner option to add CSRF protection
|
|
||||||
func WithCSRFProtection() RunnerOption {
|
|
||||||
return func(r *Runner) {
|
|
||||||
r.AddInitHook(func(state *luajit.State, ctx *Context) error {
|
|
||||||
// Get request method
|
|
||||||
method, ok := ctx.Get("method").(string)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only validate for form submissions
|
|
||||||
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for form data
|
|
||||||
form, ok := ctx.Get("form").(map[string]any)
|
|
||||||
if !ok || form == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate CSRF token
|
|
||||||
if !ValidateCSRFToken(state, ctx) {
|
|
||||||
return ErrCSRFValidationFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error for CSRF validation failure
|
|
||||||
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
|
|
||||||
|
|
||||||
// CSRFError represents a CSRF validation error
|
|
||||||
type CSRFError struct {
|
|
||||||
message string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error implements the error interface
|
|
||||||
func (e *CSRFError) Error() string {
|
|
||||||
return e.message
|
|
||||||
}
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/panjf2000/ants/v2"
|
"github.com/panjf2000/ants/v2"
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
|
|
||||||
|
"Moonshark/core/runner/sandbox"
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
@ -30,11 +31,11 @@ type RunnerOption func(*Runner)
|
||||||
|
|
||||||
// State wraps a Lua state with its sandbox
|
// State wraps a Lua state with its sandbox
|
||||||
type State struct {
|
type State struct {
|
||||||
L *luajit.State // The Lua state
|
L *luajit.State // The Lua state
|
||||||
sandbox *Sandbox // Associated sandbox
|
sandbox *sandbox.Sandbox // Associated sandbox
|
||||||
index int // Index for debugging
|
index int // Index for debugging
|
||||||
inUse bool // Whether the state is currently in use
|
inUse bool // Whether the state is currently in use
|
||||||
initTime time.Time // When this state was initialized
|
initTime time.Time // When this state was initialized
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitHook runs before executing a script
|
// InitHook runs before executing a script
|
||||||
|
@ -217,7 +218,7 @@ func (r *Runner) createState(index int) (*State, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create sandbox
|
// Create sandbox
|
||||||
sandbox := NewSandbox()
|
sandbox := sandbox.NewSandbox()
|
||||||
if r.debug && verbose {
|
if r.debug && verbose {
|
||||||
sandbox.EnableDebug()
|
sandbox.EnableDebug()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,177 +0,0 @@
|
||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LuaSessionModule provides session functionality to Lua scripts
|
|
||||||
const LuaSessionModule = `
|
|
||||||
-- Global table to store session data
|
|
||||||
__session_data = __session_data or {}
|
|
||||||
__session_id = __session_id or nil
|
|
||||||
__session_modified = false
|
|
||||||
|
|
||||||
-- Session module implementation
|
|
||||||
local session = {
|
|
||||||
-- Get a session value
|
|
||||||
get = function(key)
|
|
||||||
if type(key) ~= "string" then
|
|
||||||
error("session.get: key must be a string", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
if __session_data and __session_data[key] then
|
|
||||||
return __session_data[key]
|
|
||||||
end
|
|
||||||
|
|
||||||
return nil
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Set a session value
|
|
||||||
set = function(key, value)
|
|
||||||
if type(key) ~= "string" then
|
|
||||||
error("session.set: key must be a string", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Ensure session data table exists
|
|
||||||
__session_data = __session_data or {}
|
|
||||||
|
|
||||||
-- Store value
|
|
||||||
__session_data[key] = value
|
|
||||||
|
|
||||||
-- Mark session as modified
|
|
||||||
__session_modified = true
|
|
||||||
|
|
||||||
return true
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Delete a session value
|
|
||||||
delete = function(key)
|
|
||||||
if type(key) ~= "string" then
|
|
||||||
error("session.delete: key must be a string", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
if __session_data then
|
|
||||||
__session_data[key] = nil
|
|
||||||
__session_modified = true
|
|
||||||
end
|
|
||||||
|
|
||||||
return true
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Clear all session data
|
|
||||||
clear = function()
|
|
||||||
__session_data = {}
|
|
||||||
__session_modified = true
|
|
||||||
return true
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Get the session ID
|
|
||||||
get_id = function()
|
|
||||||
return __session_id or nil
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Get all session data
|
|
||||||
get_all = function()
|
|
||||||
local result = {}
|
|
||||||
for k, v in pairs(__session_data or {}) do
|
|
||||||
result[k] = v
|
|
||||||
end
|
|
||||||
return result
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Check if session has a key
|
|
||||||
has = function(key)
|
|
||||||
if type(key) ~= "string" then
|
|
||||||
error("session.has: key must be a string", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
return __session_data and __session_data[key] ~= nil
|
|
||||||
end
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Install session module
|
|
||||||
_G.session = session
|
|
||||||
|
|
||||||
-- Make sure the session module is accessible in sandbox
|
|
||||||
if __env_system and __env_system.base_env then
|
|
||||||
__env_system.base_env.session = session
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Hook into script execution to preserve session state
|
|
||||||
local old_execute_script = __execute_script
|
|
||||||
if old_execute_script then
|
|
||||||
__execute_script = function(fn, ctx)
|
|
||||||
-- Reset modification flag at the start of request
|
|
||||||
__session_modified = false
|
|
||||||
|
|
||||||
-- Execute original function
|
|
||||||
return old_execute_script(fn, ctx)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
`
|
|
||||||
|
|
||||||
// GetSessionData extracts session data from Lua state
|
|
||||||
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
|
|
||||||
// Check if session was modified
|
|
||||||
state.GetGlobal("__session_modified")
|
|
||||||
modified := state.ToBoolean(-1)
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
if !modified {
|
|
||||||
return "", nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get session ID
|
|
||||||
state.GetGlobal("__session_id")
|
|
||||||
sessionID := state.ToString(-1)
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get session data
|
|
||||||
state.GetGlobal("__session_data")
|
|
||||||
if !state.IsTable(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
return sessionID, nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := state.ToTable(-1)
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to extract session data: %v", err)
|
|
||||||
return sessionID, nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return sessionID, data, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSessionData sets session data in Lua state
|
|
||||||
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
|
|
||||||
// Set session ID
|
|
||||||
state.PushString(sessionID)
|
|
||||||
state.SetGlobal("__session_id")
|
|
||||||
|
|
||||||
// Set session data
|
|
||||||
if data == nil {
|
|
||||||
data = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := state.PushTable(data); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
state.SetGlobal("__session_data")
|
|
||||||
|
|
||||||
// Reset modification flag
|
|
||||||
state.PushBoolean(false)
|
|
||||||
state.SetGlobal("__session_modified")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SessionModuleInitFunc returns an initializer for the session module
|
|
||||||
func SessionModuleInitFunc() StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
return state.DoString(LuaSessionModule)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,6 +3,7 @@ package runner
|
||||||
import (
|
import (
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
"Moonshark/core/runner/sandbox"
|
||||||
"Moonshark/core/sessions"
|
"Moonshark/core/sessions"
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
|
@ -40,9 +41,6 @@ func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
|
||||||
return func(r *Runner) {
|
return func(r *Runner) {
|
||||||
handler := NewSessionHandler(manager)
|
handler := NewSessionHandler(manager)
|
||||||
|
|
||||||
// Register the session module
|
|
||||||
RegisterCoreModule("session", SessionModuleInitFunc())
|
|
||||||
|
|
||||||
// Add hooks to the runner
|
// Add hooks to the runner
|
||||||
r.AddInitHook(handler.preRequestHook)
|
r.AddInitHook(handler.preRequestHook)
|
||||||
r.AddFinalizeHook(handler.postRequestHook)
|
r.AddFinalizeHook(handler.postRequestHook)
|
||||||
|
@ -140,8 +138,10 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu
|
||||||
session.Set(k, v)
|
session.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.manager.SaveSession(session)
|
||||||
|
|
||||||
// Add session cookie to result if it's an HTTP response
|
// Add session cookie to result if it's an HTTP response
|
||||||
if httpResp, ok := result.(*HTTPResponse); ok {
|
if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
|
||||||
h.addSessionCookie(httpResp, modifiedID)
|
h.addSessionCookie(httpResp, modifiedID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu
|
||||||
}
|
}
|
||||||
|
|
||||||
// addSessionCookie adds a session cookie to an HTTP response
|
// addSessionCookie adds a session cookie to an HTTP response
|
||||||
func (h *SessionHandler) addSessionCookie(resp *HTTPResponse, sessionID string) {
|
func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) {
|
||||||
// Get cookie options
|
// Get cookie options
|
||||||
opts := h.manager.CookieOptions()
|
opts := h.manager.CookieOptions()
|
||||||
|
|
||||||
|
@ -184,3 +184,60 @@ func (h *SessionHandler) addSessionCookie(resp *HTTPResponse, sessionID string)
|
||||||
|
|
||||||
resp.Cookies = append(resp.Cookies, cookie)
|
resp.Cookies = append(resp.Cookies, cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSessionData extracts session data from Lua state
|
||||||
|
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
|
||||||
|
// Check if session was modified
|
||||||
|
state.GetGlobal("__session_modified")
|
||||||
|
modified := state.ToBoolean(-1)
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
if !modified {
|
||||||
|
return "", nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get session ID
|
||||||
|
state.GetGlobal("__session_id")
|
||||||
|
sessionID := state.ToString(-1)
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get session data
|
||||||
|
state.GetGlobal("__session_data")
|
||||||
|
if !state.IsTable(-1) {
|
||||||
|
state.Pop(1)
|
||||||
|
return sessionID, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := state.ToTable(-1)
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to extract session data: %v", err)
|
||||||
|
return sessionID, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessionID, data, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSessionData sets session data in Lua state
|
||||||
|
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
|
||||||
|
// Set session ID
|
||||||
|
state.PushString(sessionID)
|
||||||
|
state.SetGlobal("__session_id")
|
||||||
|
|
||||||
|
// Set session data
|
||||||
|
if data == nil {
|
||||||
|
data = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := state.PushTable(data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
state.SetGlobal("__session_data")
|
||||||
|
|
||||||
|
// Reset modification flag
|
||||||
|
state.PushBoolean(false)
|
||||||
|
state.SetGlobal("__session_modified")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
98
core/runner/sandbox/Embed.go
Normal file
98
core/runner/sandbox/Embed.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package sandbox
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
|
||||||
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed lua/sandbox.lua
|
||||||
|
var sandboxLua string
|
||||||
|
|
||||||
|
// InitializeSandbox loads the embedded Lua sandbox code into a Lua state
|
||||||
|
func InitializeSandbox(state *luajit.State) error {
|
||||||
|
// Compile once, use many times
|
||||||
|
bytecodeOnce.Do(precompileSandbox)
|
||||||
|
|
||||||
|
if sandboxBytecode != nil {
|
||||||
|
logger.Debug("Loading sandbox.lua from precompiled bytecode")
|
||||||
|
return state.LoadAndRunBytecode(sandboxBytecode, "sandbox.lua")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback if compilation failed
|
||||||
|
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
|
||||||
|
return state.DoString(sandboxLua)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModuleInitializers stores initializer functions for core modules
|
||||||
|
type ModuleInitializers struct {
|
||||||
|
HTTP func(*luajit.State) error
|
||||||
|
Util func(*luajit.State) error
|
||||||
|
Session func(*luajit.State) error
|
||||||
|
Cookie func(*luajit.State) error
|
||||||
|
CSRF func(*luajit.State) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultInitializers returns the default set of initializers
|
||||||
|
func DefaultInitializers() *ModuleInitializers {
|
||||||
|
return &ModuleInitializers{
|
||||||
|
HTTP: func(state *luajit.State) error {
|
||||||
|
// Register the native Go function first
|
||||||
|
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||||
|
logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Util: func(state *luajit.State) error {
|
||||||
|
// Register util functions
|
||||||
|
return RegisterModule(state, "util", UtilModuleFunctions())
|
||||||
|
},
|
||||||
|
Session: func(state *luajit.State) error {
|
||||||
|
// Session doesn't need special initialization
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Cookie: func(state *luajit.State) error {
|
||||||
|
// Cookie doesn't need special initialization
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
CSRF: func(state *luajit.State) error {
|
||||||
|
// CSRF doesn't need special initialization
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitializeAll initializes all modules in the Lua state
|
||||||
|
func InitializeAll(state *luajit.State, initializers *ModuleInitializers) error {
|
||||||
|
// Set up dependencies first
|
||||||
|
if err := initializers.Util(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := initializers.HTTP(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the embedded sandbox code
|
||||||
|
if err := InitializeSandbox(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the rest of the modules
|
||||||
|
if err := initializers.Session(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := initializers.Cookie(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := initializers.CSRF(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package runner
|
package sandbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -28,7 +28,7 @@ type HTTPResponse struct {
|
||||||
|
|
||||||
// Response pool to reduce allocations
|
// Response pool to reduce allocations
|
||||||
var responsePool = sync.Pool{
|
var responsePool = sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() any {
|
||||||
return &HTTPResponse{
|
return &HTTPResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Headers: make(map[string]string, 8), // Pre-allocate with reasonable capacity
|
Headers: make(map[string]string, 8), // Pre-allocate with reasonable capacity
|
||||||
|
@ -37,36 +37,6 @@ var responsePool = sync.Pool{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool
|
|
||||||
func NewHTTPResponse() *HTTPResponse {
|
|
||||||
return responsePool.Get().(*HTTPResponse)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReleaseResponse returns the response to the pool after clearing its values
|
|
||||||
func ReleaseResponse(resp *HTTPResponse) {
|
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear all values to prevent data leakage
|
|
||||||
resp.Status = 200 // Reset to default
|
|
||||||
|
|
||||||
// Clear headers
|
|
||||||
for k := range resp.Headers {
|
|
||||||
delete(resp.Headers, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear cookies
|
|
||||||
resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0
|
|
||||||
|
|
||||||
// Clear body
|
|
||||||
resp.Body = nil
|
|
||||||
|
|
||||||
responsePool.Put(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- HTTP CLIENT FUNCTIONALITY ----------
|
|
||||||
|
|
||||||
// Default HTTP client with sensible timeout
|
// Default HTTP client with sensible timeout
|
||||||
var defaultFastClient fasthttp.Client = fasthttp.Client{
|
var defaultFastClient fasthttp.Client = fasthttp.Client{
|
||||||
MaxConnsPerHost: 1024,
|
MaxConnsPerHost: 1024,
|
||||||
|
@ -96,8 +66,256 @@ var DefaultHTTPClientConfig = HTTPClientConfig{
|
||||||
AllowRemote: true,
|
AllowRemote: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function name constant to ensure consistency
|
// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool
|
||||||
const httpRequestFuncName = "__http_request"
|
func NewHTTPResponse() *HTTPResponse {
|
||||||
|
return responsePool.Get().(*HTTPResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReleaseResponse returns the response to the pool after clearing its values
|
||||||
|
func ReleaseResponse(resp *HTTPResponse) {
|
||||||
|
if resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear all values to prevent data leakage
|
||||||
|
resp.Status = 200 // Reset to default
|
||||||
|
|
||||||
|
// Clear headers
|
||||||
|
for k := range resp.Headers {
|
||||||
|
delete(resp.Headers, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear cookies
|
||||||
|
resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0
|
||||||
|
|
||||||
|
// Clear body
|
||||||
|
resp.Body = nil
|
||||||
|
|
||||||
|
responsePool.Put(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPModuleInitFunc returns an initializer function for the HTTP module
|
||||||
|
func HTTPModuleInitFunc() func(*luajit.State) error {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
// Register the native Go function first
|
||||||
|
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||||
|
logger.Error("[HTTP Module] Failed to register __http_request function")
|
||||||
|
logger.ErrorCont("%v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up default HTTP client configuration
|
||||||
|
setupHTTPClientConfig(state)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to set up HTTP client config
|
||||||
|
func setupHTTPClientConfig(state *luajit.State) {
|
||||||
|
state.NewTable()
|
||||||
|
|
||||||
|
state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
|
||||||
|
state.SetField(-2, "max_timeout")
|
||||||
|
|
||||||
|
state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second))
|
||||||
|
state.SetField(-2, "default_timeout")
|
||||||
|
|
||||||
|
state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize))
|
||||||
|
state.SetField(-2, "max_response_size")
|
||||||
|
|
||||||
|
state.PushBoolean(DefaultHTTPClientConfig.AllowRemote)
|
||||||
|
state.SetField(-2, "allow_remote")
|
||||||
|
|
||||||
|
state.SetGlobal("__http_client_config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHTTPResponse extracts the HTTP response from Lua state
|
||||||
|
func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
|
||||||
|
response := NewHTTPResponse()
|
||||||
|
|
||||||
|
// Get response table
|
||||||
|
state.GetGlobal("__http_responses")
|
||||||
|
if state.IsNil(-1) {
|
||||||
|
state.Pop(1)
|
||||||
|
ReleaseResponse(response) // Return unused response to pool
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for response at thread index
|
||||||
|
state.PushNumber(1)
|
||||||
|
state.GetTable(-2)
|
||||||
|
if state.IsNil(-1) {
|
||||||
|
state.Pop(2)
|
||||||
|
ReleaseResponse(response) // Return unused response to pool
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get status
|
||||||
|
state.GetField(-1, "status")
|
||||||
|
if state.IsNumber(-1) {
|
||||||
|
response.Status = int(state.ToNumber(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get headers
|
||||||
|
state.GetField(-1, "headers")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
// Iterate through headers table
|
||||||
|
state.PushNil() // Start iteration
|
||||||
|
for state.Next(-2) {
|
||||||
|
// Stack has key at -2 and value at -1
|
||||||
|
if state.IsString(-2) && state.IsString(-1) {
|
||||||
|
key := state.ToString(-2)
|
||||||
|
value := state.ToString(-1)
|
||||||
|
response.Headers[key] = value
|
||||||
|
}
|
||||||
|
state.Pop(1) // Pop value, leave key for next iteration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get cookies
|
||||||
|
state.GetField(-1, "cookies")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
// Iterate through cookies array
|
||||||
|
length := state.GetTableLength(-1)
|
||||||
|
for i := 1; i <= length; i++ {
|
||||||
|
state.PushNumber(float64(i))
|
||||||
|
state.GetTable(-2)
|
||||||
|
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
cookie := extractCookie(state)
|
||||||
|
if cookie != nil {
|
||||||
|
response.Cookies = append(response.Cookies, cookie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
state.Pop(2) // Pop response table and __http_responses
|
||||||
|
|
||||||
|
return response, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
|
||||||
|
func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
|
||||||
|
// Set status code
|
||||||
|
ctx.SetStatusCode(httpResp.Status)
|
||||||
|
|
||||||
|
// Set headers
|
||||||
|
for name, value := range httpResp.Headers {
|
||||||
|
ctx.Response.Header.Set(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set cookies
|
||||||
|
for _, cookie := range httpResp.Cookies {
|
||||||
|
ctx.Response.Header.SetCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the body based on its type
|
||||||
|
if httpResp.Body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set body based on type
|
||||||
|
switch body := httpResp.Body.(type) {
|
||||||
|
case string:
|
||||||
|
ctx.SetBodyString(body)
|
||||||
|
case []byte:
|
||||||
|
ctx.SetBody(body)
|
||||||
|
case map[string]any, []any, []float64, []string, []int:
|
||||||
|
// Marshal JSON using a buffer from the pool
|
||||||
|
buf := bytebufferpool.Get()
|
||||||
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
|
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
||||||
|
// Set content type if not already set
|
||||||
|
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||||
|
ctx.Response.Header.SetContentType("application/json")
|
||||||
|
}
|
||||||
|
ctx.SetBody(buf.Bytes())
|
||||||
|
} else {
|
||||||
|
// Fallback
|
||||||
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Default to string representation
|
||||||
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCookie grabs cookies from the Lua state
|
||||||
|
func extractCookie(state *luajit.State) *fasthttp.Cookie {
|
||||||
|
cookie := fasthttp.AcquireCookie()
|
||||||
|
|
||||||
|
// Get name
|
||||||
|
state.GetField(-1, "name")
|
||||||
|
if !state.IsString(-1) {
|
||||||
|
state.Pop(1)
|
||||||
|
fasthttp.ReleaseCookie(cookie)
|
||||||
|
return nil // Name is required
|
||||||
|
}
|
||||||
|
cookie.SetKey(state.ToString(-1))
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get value
|
||||||
|
state.GetField(-1, "value")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
cookie.SetValue(state.ToString(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get path
|
||||||
|
state.GetField(-1, "path")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
cookie.SetPath(state.ToString(-1))
|
||||||
|
} else {
|
||||||
|
cookie.SetPath("/") // Default path
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get domain
|
||||||
|
state.GetField(-1, "domain")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
cookie.SetDomain(state.ToString(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get expires
|
||||||
|
state.GetField(-1, "expires")
|
||||||
|
if state.IsNumber(-1) {
|
||||||
|
expiry := int64(state.ToNumber(-1))
|
||||||
|
cookie.SetExpire(time.Unix(expiry, 0))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get max age
|
||||||
|
state.GetField(-1, "max_age")
|
||||||
|
if state.IsNumber(-1) {
|
||||||
|
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get secure
|
||||||
|
state.GetField(-1, "secure")
|
||||||
|
if state.IsBoolean(-1) {
|
||||||
|
cookie.SetSecure(state.ToBoolean(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get http only
|
||||||
|
state.GetField(-1, "http_only")
|
||||||
|
if state.IsBoolean(-1) {
|
||||||
|
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
return cookie
|
||||||
|
}
|
||||||
|
|
||||||
// httpRequest makes an HTTP request and returns the result to Lua
|
// httpRequest makes an HTTP request and returns the result to Lua
|
||||||
func httpRequest(state *luajit.State) int {
|
func httpRequest(state *luajit.State) int {
|
||||||
|
@ -360,372 +578,3 @@ func httpRequest(state *luajit.State) int {
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPModuleInitFunc returns an initializer function for the HTTP module
|
|
||||||
func HTTPModuleInitFunc() StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
// CRITICAL: Register the native Go function first
|
|
||||||
// This must be done BEFORE any Lua code that references it
|
|
||||||
if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil {
|
|
||||||
logger.Error("[HTTP Module] Failed to register __http_request function")
|
|
||||||
logger.ErrorCont("%v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up default HTTP client configuration
|
|
||||||
setupHTTPClientConfig(state)
|
|
||||||
|
|
||||||
// Initialize Lua HTTP module
|
|
||||||
if err := state.DoString(LuaHTTPModule); err != nil {
|
|
||||||
logger.Error("[HTTP Module] Failed to initialize HTTP module Lua code")
|
|
||||||
logger.ErrorCont("%v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify HTTP client functions are available
|
|
||||||
verifyHTTPClient(state)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to set up HTTP client config
|
|
||||||
func setupHTTPClientConfig(state *luajit.State) {
|
|
||||||
state.NewTable()
|
|
||||||
|
|
||||||
state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
|
|
||||||
state.SetField(-2, "max_timeout")
|
|
||||||
|
|
||||||
state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second))
|
|
||||||
state.SetField(-2, "default_timeout")
|
|
||||||
|
|
||||||
state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize))
|
|
||||||
state.SetField(-2, "max_response_size")
|
|
||||||
|
|
||||||
state.PushBoolean(DefaultHTTPClientConfig.AllowRemote)
|
|
||||||
state.SetField(-2, "allow_remote")
|
|
||||||
|
|
||||||
state.SetGlobal("__http_client_config")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetHTTPResponse extracts the HTTP response from Lua state
|
|
||||||
func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
|
|
||||||
response := NewHTTPResponse()
|
|
||||||
|
|
||||||
// Get response table
|
|
||||||
state.GetGlobal("__http_responses")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
ReleaseResponse(response) // Return unused response to pool
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for response at thread index
|
|
||||||
state.PushNumber(1)
|
|
||||||
state.GetTable(-2)
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(2)
|
|
||||||
ReleaseResponse(response) // Return unused response to pool
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get status
|
|
||||||
state.GetField(-1, "status")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
response.Status = int(state.ToNumber(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get headers
|
|
||||||
state.GetField(-1, "headers")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
// Iterate through headers table
|
|
||||||
state.PushNil() // Start iteration
|
|
||||||
for state.Next(-2) {
|
|
||||||
// Stack has key at -2 and value at -1
|
|
||||||
if state.IsString(-2) && state.IsString(-1) {
|
|
||||||
key := state.ToString(-2)
|
|
||||||
value := state.ToString(-1)
|
|
||||||
response.Headers[key] = value
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop value, leave key for next iteration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get cookies
|
|
||||||
state.GetField(-1, "cookies")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
// Iterate through cookies array
|
|
||||||
length := state.GetTableLength(-1)
|
|
||||||
for i := 1; i <= length; i++ {
|
|
||||||
state.PushNumber(float64(i))
|
|
||||||
state.GetTable(-2)
|
|
||||||
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
cookie := extractCookie(state)
|
|
||||||
if cookie != nil {
|
|
||||||
response.Cookies = append(response.Cookies, cookie)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
state.Pop(2) // Pop response table and __http_responses
|
|
||||||
|
|
||||||
return response, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
|
|
||||||
func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
|
|
||||||
// Set status code
|
|
||||||
ctx.SetStatusCode(httpResp.Status)
|
|
||||||
|
|
||||||
// Set headers
|
|
||||||
for name, value := range httpResp.Headers {
|
|
||||||
ctx.Response.Header.Set(name, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set cookies
|
|
||||||
for _, cookie := range httpResp.Cookies {
|
|
||||||
ctx.Response.Header.SetCookie(cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process the body based on its type
|
|
||||||
if httpResp.Body == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set body based on type
|
|
||||||
switch body := httpResp.Body.(type) {
|
|
||||||
case string:
|
|
||||||
ctx.SetBodyString(body)
|
|
||||||
case []byte:
|
|
||||||
ctx.SetBody(body)
|
|
||||||
case map[string]any, []any, []float64, []string, []int:
|
|
||||||
// Marshal JSON using a buffer from the pool
|
|
||||||
buf := bytebufferpool.Get()
|
|
||||||
defer bytebufferpool.Put(buf)
|
|
||||||
|
|
||||||
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
|
||||||
// Set content type if not already set
|
|
||||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
|
||||||
ctx.Response.Header.SetContentType("application/json")
|
|
||||||
}
|
|
||||||
ctx.SetBody(buf.Bytes())
|
|
||||||
} else {
|
|
||||||
// Fallback
|
|
||||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// Default to string representation
|
|
||||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithHTTPClientConfig creates a runner option to configure the HTTP client
|
|
||||||
func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption {
|
|
||||||
return func(r *Runner) {
|
|
||||||
// Store the config to be applied during initialization
|
|
||||||
r.AddModule("__http_client_config", map[string]any{
|
|
||||||
"max_timeout": float64(config.MaxTimeout / time.Second),
|
|
||||||
"default_timeout": float64(config.DefaultTimeout / time.Second),
|
|
||||||
"max_response_size": float64(config.MaxResponseSize),
|
|
||||||
"allow_remote": config.AllowRemote,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestrictHTTPToLocalhost is a convenience function to restrict HTTP client
|
|
||||||
// to localhost connections only
|
|
||||||
func RestrictHTTPToLocalhost() RunnerOption {
|
|
||||||
return WithHTTPClientConfig(HTTPClientConfig{
|
|
||||||
MaxTimeout: DefaultHTTPClientConfig.MaxTimeout,
|
|
||||||
DefaultTimeout: DefaultHTTPClientConfig.DefaultTimeout,
|
|
||||||
MaxResponseSize: DefaultHTTPClientConfig.MaxResponseSize,
|
|
||||||
AllowRemote: false,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify that HTTP client is properly set up
|
|
||||||
func verifyHTTPClient(state *luajit.State) {
|
|
||||||
// Get the client table
|
|
||||||
state.GetGlobal("http")
|
|
||||||
if !state.IsTable(-1) {
|
|
||||||
logger.Warning("[HTTP Module] 'http' is not a table")
|
|
||||||
state.Pop(1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "client")
|
|
||||||
if !state.IsTable(-1) {
|
|
||||||
logger.Warning("[HTTP Module] 'http.client' is not a table")
|
|
||||||
state.Pop(2)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for get function
|
|
||||||
state.GetField(-1, "get")
|
|
||||||
if !state.IsFunction(-1) {
|
|
||||||
logger.Warning("[HTTP Module] 'http.client.get' is not a function")
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Check for the request function
|
|
||||||
state.GetField(-1, "request")
|
|
||||||
if !state.IsFunction(-1) {
|
|
||||||
logger.Warning("[HTTP Module] 'http.client.request' is not a function")
|
|
||||||
}
|
|
||||||
state.Pop(3) // Pop request, client, http
|
|
||||||
}
|
|
||||||
|
|
||||||
const LuaHTTPModule = `
|
|
||||||
-- Table to store response data
|
|
||||||
__http_responses = {}
|
|
||||||
|
|
||||||
-- HTTP module implementation
|
|
||||||
local http = {
|
|
||||||
-- Set HTTP status code
|
|
||||||
set_status = function(code)
|
|
||||||
if type(code) ~= "number" then
|
|
||||||
error("http.set_status: status code must be a number", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
local resp = __http_responses[1] or {}
|
|
||||||
resp.status = code
|
|
||||||
__http_responses[1] = resp
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Set HTTP header
|
|
||||||
set_header = function(name, value)
|
|
||||||
if type(name) ~= "string" or type(value) ~= "string" then
|
|
||||||
error("http.set_header: name and value must be strings", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
local resp = __http_responses[1] or {}
|
|
||||||
resp.headers = resp.headers or {}
|
|
||||||
resp.headers[name] = value
|
|
||||||
__http_responses[1] = resp
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Set content type; set_header helper
|
|
||||||
set_content_type = function(content_type)
|
|
||||||
http.set_header("Content-Type", content_type)
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- HTTP client submodule
|
|
||||||
client = {
|
|
||||||
-- Generic request function
|
|
||||||
request = function(method, url, body, options)
|
|
||||||
if type(method) ~= "string" then
|
|
||||||
error("http.client.request: method must be a string", 2)
|
|
||||||
end
|
|
||||||
if type(url) ~= "string" then
|
|
||||||
error("http.client.request: url must be a string", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Call native implementation (this is the critical part)
|
|
||||||
local result = __http_request(method, url, body, options)
|
|
||||||
return result
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Simple GET request
|
|
||||||
get = function(url, options)
|
|
||||||
return http.client.request("GET", url, nil, options)
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Simple POST request with automatic content-type
|
|
||||||
post = function(url, body, options)
|
|
||||||
options = options or {}
|
|
||||||
return http.client.request("POST", url, body, options)
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Simple PUT request with automatic content-type
|
|
||||||
put = function(url, body, options)
|
|
||||||
options = options or {}
|
|
||||||
return http.client.request("PUT", url, body, options)
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Simple DELETE request
|
|
||||||
delete = function(url, options)
|
|
||||||
return http.client.request("DELETE", url, nil, options)
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Simple PATCH request
|
|
||||||
patch = function(url, body, options)
|
|
||||||
options = options or {}
|
|
||||||
return http.client.request("PATCH", url, body, options)
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Simple HEAD request
|
|
||||||
head = function(url, options)
|
|
||||||
options = options or {}
|
|
||||||
local old_options = options
|
|
||||||
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
|
|
||||||
local response = http.client.request("HEAD", url, nil, options)
|
|
||||||
return response
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Simple OPTIONS request
|
|
||||||
options = function(url, options)
|
|
||||||
return http.client.request("OPTIONS", url, nil, options)
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Shorthand function to directly get JSON
|
|
||||||
get_json = function(url, options)
|
|
||||||
options = options or {}
|
|
||||||
local response = http.client.get(url, options)
|
|
||||||
if response.ok and response.json then
|
|
||||||
return response.json
|
|
||||||
end
|
|
||||||
return nil, response
|
|
||||||
end,
|
|
||||||
|
|
||||||
-- Utility to build a URL with query parameters
|
|
||||||
build_url = function(base_url, params)
|
|
||||||
if not params or type(params) ~= "table" then
|
|
||||||
return base_url
|
|
||||||
end
|
|
||||||
|
|
||||||
local query = {}
|
|
||||||
for k, v in pairs(params) do
|
|
||||||
if type(v) == "table" then
|
|
||||||
for _, item in ipairs(v) do
|
|
||||||
table.insert(query, k .. "=" .. tostring(item))
|
|
||||||
end
|
|
||||||
else
|
|
||||||
table.insert(query, k .. "=" .. tostring(v))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if #query > 0 then
|
|
||||||
if base_url:find("?") then
|
|
||||||
return base_url .. "&" .. table.concat(query, "&")
|
|
||||||
else
|
|
||||||
return base_url .. "?" .. table.concat(query, "&")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return base_url
|
|
||||||
end
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Install HTTP module
|
|
||||||
_G.http = http
|
|
||||||
|
|
||||||
-- Clear previous responses when executing scripts
|
|
||||||
local old_execute_script = __execute_script
|
|
||||||
if old_execute_script then
|
|
||||||
__execute_script = function(fn, ctx)
|
|
||||||
-- Clear previous response
|
|
||||||
__http_responses[1] = nil
|
|
||||||
|
|
||||||
-- Execute original function
|
|
||||||
return old_execute_script(fn, ctx)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
`
|
|
86
core/runner/sandbox/Modules.go
Normal file
86
core/runner/sandbox/Modules.go
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
package sandbox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModuleFunc is a function that returns a map of module functions
|
||||||
|
type ModuleFunc func() map[string]luajit.GoFunction
|
||||||
|
|
||||||
|
// RegisterModule registers a map of functions as a Lua module
|
||||||
|
func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
|
||||||
|
// Create a new table for the module
|
||||||
|
state.NewTable()
|
||||||
|
|
||||||
|
// Add each function to the module table
|
||||||
|
for fname, f := range funcs {
|
||||||
|
// Push function name
|
||||||
|
state.PushString(fname)
|
||||||
|
|
||||||
|
// Push function
|
||||||
|
if err := state.PushGoFunction(f); err != nil {
|
||||||
|
state.Pop(1) // Pop table
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set table[fname] = f
|
||||||
|
state.SetTable(-3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register the module globally
|
||||||
|
state.SetGlobal(name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CombineInitFuncs combines multiple state initializer functions into one
|
||||||
|
func CombineInitFuncs(funcs ...func(*luajit.State) error) func(*luajit.State) error {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
for _, f := range funcs {
|
||||||
|
if f != nil {
|
||||||
|
if err := f(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModuleInitFunc creates a state initializer that registers multiple modules
|
||||||
|
func ModuleInitFunc(modules map[string]ModuleFunc) func(*luajit.State) error {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
for name, moduleFunc := range modules {
|
||||||
|
if err := RegisterModule(state, name, moduleFunc()); err != nil {
|
||||||
|
logger.Error("Failed to register module %s: %v", name, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterLuaCode registers a Lua code snippet as a module
|
||||||
|
func RegisterLuaCode(state *luajit.State, code string) error {
|
||||||
|
return state.DoString(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code
|
||||||
|
func RegisterLuaCodeInitFunc(code string) func(*luajit.State) error {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
return RegisterLuaCode(state, code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module
|
||||||
|
func RegisterLuaModuleInitFunc(name string, code string) func(*luajit.State) error {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
// Create name = {} global
|
||||||
|
state.NewTable()
|
||||||
|
state.SetGlobal(name)
|
||||||
|
|
||||||
|
// Then run the module code which will populate it
|
||||||
|
return state.DoString(code)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package runner
|
package sandbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -6,24 +6,52 @@ import (
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
"github.com/goccy/go-json"
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Global bytecode cache to improve performance
|
||||||
|
var (
|
||||||
|
sandboxBytecode []byte
|
||||||
|
bytecodeOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
// precompileSandbox compiles the sandbox.lua code to bytecode once
|
||||||
|
func precompileSandbox() {
|
||||||
|
tempState := luajit.New()
|
||||||
|
if tempState == nil {
|
||||||
|
logger.Error("Failed to create temporary Lua state for bytecode compilation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tempState.Close()
|
||||||
|
defer tempState.Cleanup()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
sandboxBytecode, err = tempState.CompileBytecode(sandboxLua, "sandbox.lua")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to precompile sandbox.lua: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(sandboxBytecode))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Sandbox provides a secure execution environment for Lua scripts
|
// Sandbox provides a secure execution environment for Lua scripts
|
||||||
type Sandbox struct {
|
type Sandbox struct {
|
||||||
modules map[string]any // Custom modules for environment
|
modules map[string]any // Custom modules for environment
|
||||||
debug bool // Enable debug output
|
debug bool // Enable debug output
|
||||||
mu sync.RWMutex // Protects modules
|
mu sync.RWMutex // Protects modules
|
||||||
|
initializers *ModuleInitializers // Module initializers
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSandbox creates a new sandbox environment
|
// NewSandbox creates a new sandbox environment
|
||||||
func NewSandbox() *Sandbox {
|
func NewSandbox() *Sandbox {
|
||||||
return &Sandbox{
|
return &Sandbox{
|
||||||
modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity
|
modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity
|
||||||
debug: false,
|
debug: false,
|
||||||
|
initializers: DefaultInitializers(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +67,7 @@ func (s *Sandbox) debugLog(format string, args ...interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// debugLog logs a message if debug mode is enabled
|
// debugLogCont logs a continuation message if debug mode is enabled
|
||||||
func (s *Sandbox) debugLogCont(format string, args ...interface{}) {
|
func (s *Sandbox) debugLogCont(format string, args ...interface{}) {
|
||||||
if s.debug {
|
if s.debug {
|
||||||
logger.DebugCont(format, args...)
|
logger.DebugCont(format, args...)
|
||||||
|
@ -60,19 +88,27 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
|
||||||
verbose := stateIndex == 0
|
verbose := stateIndex == 0
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
s.debugLog("is setting up...")
|
s.debugLog("Setting up sandbox...")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register modules in the global environment
|
// Initialize modules with the embedded sandbox code
|
||||||
|
if err := InitializeAll(state, s.initializers); err != nil {
|
||||||
|
if verbose {
|
||||||
|
s.debugLog("Failed to initialize sandbox: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register custom modules in the global environment
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
for name, module := range s.modules {
|
for name, module := range s.modules {
|
||||||
if verbose {
|
if verbose {
|
||||||
s.debugLog("is registering module: %s", name)
|
s.debugLog("Registering module: %s", name)
|
||||||
}
|
}
|
||||||
if err := state.PushValue(module); err != nil {
|
if err := state.PushValue(module); err != nil {
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
if verbose {
|
if verbose {
|
||||||
s.debugLog("failed to register module %s: %v", name, err)
|
s.debugLog("Failed to register module %s: %v", name, err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -80,60 +116,8 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
|
||||||
}
|
}
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
// Initialize environment setup
|
|
||||||
err := state.DoString(`
|
|
||||||
-- Global tables for response handling
|
|
||||||
__http_responses = __http_responses or {}
|
|
||||||
|
|
||||||
-- Create environment inheriting from _G
|
|
||||||
function __create_env(ctx)
|
|
||||||
-- Create environment with metatable inheriting from _G
|
|
||||||
local env = setmetatable({}, {__index = _G})
|
|
||||||
|
|
||||||
-- Add context if provided
|
|
||||||
if ctx then
|
|
||||||
env.ctx = ctx
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Add proper require function to this environment
|
|
||||||
if __setup_require then
|
|
||||||
__setup_require(env)
|
|
||||||
end
|
|
||||||
|
|
||||||
return env
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Execute script with clean environment
|
|
||||||
function __execute_script(fn, ctx)
|
|
||||||
-- Clear previous responses
|
|
||||||
__http_responses[1] = nil
|
|
||||||
|
|
||||||
-- Create environment
|
|
||||||
local env = __create_env(ctx)
|
|
||||||
|
|
||||||
-- Set environment for function
|
|
||||||
setfenv(fn, env)
|
|
||||||
|
|
||||||
-- Execute with protected call
|
|
||||||
local ok, result = pcall(fn)
|
|
||||||
if not ok then
|
|
||||||
error(result, 0)
|
|
||||||
end
|
|
||||||
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
`)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if verbose {
|
|
||||||
s.debugLog("failed to set up...")
|
|
||||||
s.debugLogCont("%v", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
s.debugLogCont("Complete")
|
s.debugLogCont("Sandbox setup complete")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -152,6 +136,14 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
|
||||||
return s.OptimizedExecute(state, bytecode, nil)
|
return s.OptimizedExecute(state, bytecode, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Context represents execution context for a Lua script
|
||||||
|
type Context struct {
|
||||||
|
// Values stores any context values (route params, HTTP request info, etc.)
|
||||||
|
Values map[string]any
|
||||||
|
// RequestCtx for HTTP requests
|
||||||
|
RequestCtx *fasthttp.RequestCtx
|
||||||
|
}
|
||||||
|
|
||||||
// OptimizedExecute runs bytecode with a fasthttp context if available
|
// OptimizedExecute runs bytecode with a fasthttp context if available
|
||||||
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Context) (any, error) {
|
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Context) (any, error) {
|
||||||
// Use a buffer from the pool for any string operations
|
// Use a buffer from the pool for any string operations
|
|
@ -1,4 +1,4 @@
|
||||||
package runner
|
package sandbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -9,6 +9,20 @@ import (
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UtilModuleInitFunc returns an initializer for the util module
|
||||||
|
func UtilModuleInitFunc() func(*luajit.State) error {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
return RegisterModule(state, "util", UtilModuleFunctions())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UtilModuleFunctions returns all functions for the util module
|
||||||
|
func UtilModuleFunctions() map[string]luajit.GoFunction {
|
||||||
|
return map[string]luajit.GoFunction{
|
||||||
|
"generate_token": GenerateToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateToken creates a cryptographically secure random token
|
// GenerateToken creates a cryptographically secure random token
|
||||||
func GenerateToken(s *luajit.State) int {
|
func GenerateToken(s *luajit.State) int {
|
||||||
// Get the length from the Lua arguments (default to 32)
|
// Get the length from the Lua arguments (default to 32)
|
||||||
|
@ -42,17 +56,3 @@ func GenerateToken(s *luajit.State) int {
|
||||||
s.PushString(token)
|
s.PushString(token)
|
||||||
return 1 // One return value
|
return 1 // One return value
|
||||||
}
|
}
|
||||||
|
|
||||||
// UtilModuleFunctions returns all functions for the go module
|
|
||||||
func UtilModuleFunctions() map[string]luajit.GoFunction {
|
|
||||||
return map[string]luajit.GoFunction{
|
|
||||||
"generate_token": GenerateToken,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UtilModuleInitFunc returns an initializer for the go module
|
|
||||||
func UtilModuleInitFunc() StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
return RegisterModule(state, "util", UtilModuleFunctions())
|
|
||||||
}
|
|
||||||
}
|
|
552
core/runner/sandbox/lua/sandbox.lua
Normal file
552
core/runner/sandbox/lua/sandbox.lua
Normal file
|
@ -0,0 +1,552 @@
|
||||||
|
--[[
|
||||||
|
Moonshark Lua Sandbox Environment
|
||||||
|
|
||||||
|
This file contains all the Lua code needed for the sandbox environment,
|
||||||
|
including core modules and utilities. It's designed to be embedded in the
|
||||||
|
Go binary at build time.
|
||||||
|
]]--
|
||||||
|
|
||||||
|
-- Global tables for execution context
|
||||||
|
__http_responses = {}
|
||||||
|
__module_paths = {}
|
||||||
|
__module_bytecode = {}
|
||||||
|
__ready_modules = {}
|
||||||
|
__session_data = {}
|
||||||
|
__session_id = nil
|
||||||
|
__session_modified = false
|
||||||
|
__env_system = {
|
||||||
|
base_env = {}
|
||||||
|
}
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- CORE SANDBOX FUNCTIONALITY
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
-- Create environment inheriting from _G
|
||||||
|
function __create_env(ctx)
|
||||||
|
-- Create environment with metatable inheriting from _G
|
||||||
|
local env = setmetatable({}, {__index = _G})
|
||||||
|
|
||||||
|
-- Add context if provided
|
||||||
|
if ctx then
|
||||||
|
env.ctx = ctx
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Add proper require function to this environment
|
||||||
|
if __setup_require then
|
||||||
|
__setup_require(env)
|
||||||
|
end
|
||||||
|
|
||||||
|
return env
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Execute script with clean environment
|
||||||
|
function __execute_script(fn, ctx)
|
||||||
|
-- Clear previous responses
|
||||||
|
__http_responses[1] = nil
|
||||||
|
|
||||||
|
-- Reset session modification flag
|
||||||
|
__session_modified = false
|
||||||
|
|
||||||
|
-- Create environment
|
||||||
|
local env = __create_env(ctx)
|
||||||
|
|
||||||
|
-- Set environment for function
|
||||||
|
setfenv(fn, env)
|
||||||
|
|
||||||
|
-- Execute with protected call
|
||||||
|
local ok, result = pcall(fn)
|
||||||
|
if not ok then
|
||||||
|
error(result, 0)
|
||||||
|
end
|
||||||
|
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- MODULE LOADING SYSTEM
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
-- Setup environment-aware require function
|
||||||
|
function __setup_require(env)
|
||||||
|
-- Create require function specific to this environment
|
||||||
|
env.require = function(modname)
|
||||||
|
-- Check if already loaded
|
||||||
|
if package.loaded[modname] then
|
||||||
|
return package.loaded[modname]
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Check preloaded modules
|
||||||
|
if __ready_modules[modname] then
|
||||||
|
local loader = package.preload[modname]
|
||||||
|
if loader then
|
||||||
|
-- Set environment for loader
|
||||||
|
setfenv(loader, env)
|
||||||
|
|
||||||
|
-- Execute and store result
|
||||||
|
local result = loader()
|
||||||
|
if result == nil then
|
||||||
|
result = true
|
||||||
|
end
|
||||||
|
|
||||||
|
package.loaded[modname] = result
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Direct file load as fallback
|
||||||
|
if __module_paths[modname] then
|
||||||
|
local path = __module_paths[modname]
|
||||||
|
local chunk, err = loadfile(path)
|
||||||
|
if chunk then
|
||||||
|
setfenv(chunk, env)
|
||||||
|
local result = chunk()
|
||||||
|
if result == nil then
|
||||||
|
result = true
|
||||||
|
end
|
||||||
|
package.loaded[modname] = result
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Full path search as last resort
|
||||||
|
local errors = {}
|
||||||
|
for path in package.path:gmatch("[^;]+") do
|
||||||
|
local file_path = path:gsub("?", modname:gsub("%.", "/"))
|
||||||
|
local chunk, err = loadfile(file_path)
|
||||||
|
if chunk then
|
||||||
|
setfenv(chunk, env)
|
||||||
|
local result = chunk()
|
||||||
|
if result == nil then
|
||||||
|
result = true
|
||||||
|
end
|
||||||
|
package.loaded[modname] = result
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
table.insert(errors, "\tno file '" .. file_path .. "'")
|
||||||
|
end
|
||||||
|
|
||||||
|
error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
return env
|
||||||
|
end
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- HTTP MODULE
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
-- HTTP module implementation
|
||||||
|
local http = {
|
||||||
|
-- Set HTTP status code
|
||||||
|
set_status = function(code)
|
||||||
|
if type(code) ~= "number" then
|
||||||
|
error("http.set_status: status code must be a number", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
local resp = __http_responses[1] or {}
|
||||||
|
resp.status = code
|
||||||
|
__http_responses[1] = resp
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Set HTTP header
|
||||||
|
set_header = function(name, value)
|
||||||
|
if type(name) ~= "string" or type(value) ~= "string" then
|
||||||
|
error("http.set_header: name and value must be strings", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
local resp = __http_responses[1] or {}
|
||||||
|
resp.headers = resp.headers or {}
|
||||||
|
resp.headers[name] = value
|
||||||
|
__http_responses[1] = resp
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Set content type; set_header helper
|
||||||
|
set_content_type = function(content_type)
|
||||||
|
http.set_header("Content-Type", content_type)
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- HTTP client submodule
|
||||||
|
client = {
|
||||||
|
-- Generic request function
|
||||||
|
request = function(method, url, body, options)
|
||||||
|
if type(method) ~= "string" then
|
||||||
|
error("http.client.request: method must be a string", 2)
|
||||||
|
end
|
||||||
|
if type(url) ~= "string" then
|
||||||
|
error("http.client.request: url must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Call native implementation
|
||||||
|
local result = __http_request(method, url, body, options)
|
||||||
|
return result
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Simple GET request
|
||||||
|
get = function(url, options)
|
||||||
|
return http.client.request("GET", url, nil, options)
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Simple POST request with automatic content-type
|
||||||
|
post = function(url, body, options)
|
||||||
|
options = options or {}
|
||||||
|
return http.client.request("POST", url, body, options)
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Simple PUT request with automatic content-type
|
||||||
|
put = function(url, body, options)
|
||||||
|
options = options or {}
|
||||||
|
return http.client.request("PUT", url, body, options)
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Simple DELETE request
|
||||||
|
delete = function(url, options)
|
||||||
|
return http.client.request("DELETE", url, nil, options)
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Simple PATCH request
|
||||||
|
patch = function(url, body, options)
|
||||||
|
options = options or {}
|
||||||
|
return http.client.request("PATCH", url, body, options)
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Simple HEAD request
|
||||||
|
head = function(url, options)
|
||||||
|
options = options or {}
|
||||||
|
local old_options = options
|
||||||
|
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
|
||||||
|
local response = http.client.request("HEAD", url, nil, options)
|
||||||
|
return response
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Simple OPTIONS request
|
||||||
|
options = function(url, options)
|
||||||
|
return http.client.request("OPTIONS", url, nil, options)
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Shorthand function to directly get JSON
|
||||||
|
get_json = function(url, options)
|
||||||
|
options = options or {}
|
||||||
|
local response = http.client.get(url, options)
|
||||||
|
if response.ok and response.json then
|
||||||
|
return response.json
|
||||||
|
end
|
||||||
|
return nil, response
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Utility to build a URL with query parameters
|
||||||
|
build_url = function(base_url, params)
|
||||||
|
if not params or type(params) ~= "table" then
|
||||||
|
return base_url
|
||||||
|
end
|
||||||
|
|
||||||
|
local query = {}
|
||||||
|
for k, v in pairs(params) do
|
||||||
|
if type(v) == "table" then
|
||||||
|
for _, item in ipairs(v) do
|
||||||
|
table.insert(query, k .. "=" .. tostring(item))
|
||||||
|
end
|
||||||
|
else
|
||||||
|
table.insert(query, k .. "=" .. tostring(v))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if #query > 0 then
|
||||||
|
if base_url:find("?") then
|
||||||
|
return base_url .. "&" .. table.concat(query, "&")
|
||||||
|
else
|
||||||
|
return base_url .. "?" .. table.concat(query, "&")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return base_url
|
||||||
|
end
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- COOKIE MODULE
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
-- Cookie module implementation
|
||||||
|
local cookie = {
|
||||||
|
-- Set a cookie
|
||||||
|
set = function(name, value, options, ...)
|
||||||
|
if type(name) ~= "string" then
|
||||||
|
error("cookie.set: name must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Get or create response
|
||||||
|
local resp = __http_responses[1] or {}
|
||||||
|
resp.cookies = resp.cookies or {}
|
||||||
|
__http_responses[1] = resp
|
||||||
|
|
||||||
|
-- Handle options as table or legacy params
|
||||||
|
local opts = {}
|
||||||
|
if type(options) == "table" then
|
||||||
|
opts = options
|
||||||
|
elseif options ~= nil then
|
||||||
|
-- Legacy support: options is actually 'expires'
|
||||||
|
opts.expires = options
|
||||||
|
-- Check for other legacy params (4th-7th args)
|
||||||
|
local args = {...}
|
||||||
|
if args[1] then opts.path = args[1] end
|
||||||
|
if args[2] then opts.domain = args[2] end
|
||||||
|
if args[3] then opts.secure = args[3] end
|
||||||
|
if args[4] ~= nil then opts.http_only = args[4] end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Create cookie table
|
||||||
|
local cookie = {
|
||||||
|
name = name,
|
||||||
|
value = value or "",
|
||||||
|
path = opts.path or "/",
|
||||||
|
domain = opts.domain
|
||||||
|
}
|
||||||
|
|
||||||
|
-- Handle expiry
|
||||||
|
if opts.expires then
|
||||||
|
if type(opts.expires) == "number" then
|
||||||
|
if opts.expires > 0 then
|
||||||
|
cookie.max_age = opts.expires
|
||||||
|
local now = os.time()
|
||||||
|
cookie.expires = now + opts.expires
|
||||||
|
elseif opts.expires < 0 then
|
||||||
|
cookie.expires = 1
|
||||||
|
cookie.max_age = 0
|
||||||
|
else
|
||||||
|
-- opts.expires == 0: Session cookie
|
||||||
|
-- Do nothing (omitting both expires and max-age creates a session cookie)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Security flags
|
||||||
|
cookie.secure = (opts.secure ~= false)
|
||||||
|
cookie.http_only = (opts.http_only ~= false)
|
||||||
|
|
||||||
|
-- Store in cookies table
|
||||||
|
local n = #resp.cookies + 1
|
||||||
|
resp.cookies[n] = cookie
|
||||||
|
|
||||||
|
return true
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Get a cookie value
|
||||||
|
get = function(name)
|
||||||
|
if type(name) ~= "string" then
|
||||||
|
error("cookie.get: name must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Access values directly from current environment
|
||||||
|
local env = getfenv(2)
|
||||||
|
|
||||||
|
-- Check if context exists and has cookies
|
||||||
|
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
|
||||||
|
return tostring(env.ctx.cookies[name])
|
||||||
|
end
|
||||||
|
|
||||||
|
return nil
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Remove a cookie
|
||||||
|
remove = function(name, path, domain)
|
||||||
|
if type(name) ~= "string" then
|
||||||
|
error("cookie.remove: name must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Create an expired cookie
|
||||||
|
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
|
||||||
|
end
|
||||||
|
}
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- SESSION MODULE
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
-- Session module implementation
|
||||||
|
local session = {
|
||||||
|
-- Get a session value
|
||||||
|
get = function(key)
|
||||||
|
if type(key) ~= "string" then
|
||||||
|
error("session.get: key must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
if __session_data and __session_data[key] then
|
||||||
|
return __session_data[key]
|
||||||
|
end
|
||||||
|
|
||||||
|
return nil
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Set a session value
|
||||||
|
set = function(key, value)
|
||||||
|
if type(key) ~= "string" then
|
||||||
|
error("session.set: key must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Ensure session data table exists
|
||||||
|
__session_data = __session_data or {}
|
||||||
|
|
||||||
|
-- Store value
|
||||||
|
__session_data[key] = value
|
||||||
|
|
||||||
|
-- Mark session as modified
|
||||||
|
__session_modified = true
|
||||||
|
|
||||||
|
return true
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Delete a session value
|
||||||
|
delete = function(key)
|
||||||
|
if type(key) ~= "string" then
|
||||||
|
error("session.delete: key must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
if __session_data then
|
||||||
|
__session_data[key] = nil
|
||||||
|
__session_modified = true
|
||||||
|
end
|
||||||
|
|
||||||
|
return true
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Clear all session data
|
||||||
|
clear = function()
|
||||||
|
__session_data = {}
|
||||||
|
__session_modified = true
|
||||||
|
return true
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Get the session ID
|
||||||
|
get_id = function()
|
||||||
|
return __session_id or nil
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Get all session data
|
||||||
|
get_all = function()
|
||||||
|
local result = {}
|
||||||
|
for k, v in pairs(__session_data or {}) do
|
||||||
|
result[k] = v
|
||||||
|
end
|
||||||
|
return result
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Check if session has a key
|
||||||
|
has = function(key)
|
||||||
|
if type(key) ~= "string" then
|
||||||
|
error("session.has: key must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
return __session_data and __session_data[key] ~= nil
|
||||||
|
end
|
||||||
|
}
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- CSRF MODULE
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
-- CSRF protection module
|
||||||
|
local csrf = {
|
||||||
|
-- Session key where the token is stored
|
||||||
|
TOKEN_KEY = "_csrf_token",
|
||||||
|
|
||||||
|
-- Default form field name
|
||||||
|
DEFAULT_FIELD = "csrf",
|
||||||
|
|
||||||
|
-- Generate a new CSRF token and store it in the session
|
||||||
|
generate = function(length)
|
||||||
|
-- Default length is 32 characters
|
||||||
|
length = length or 32
|
||||||
|
|
||||||
|
if length < 16 then
|
||||||
|
-- Enforce minimum security
|
||||||
|
length = 16
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Check if we have a session module
|
||||||
|
if not session then
|
||||||
|
error("CSRF protection requires the session module", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
local token = util.generate_token(length)
|
||||||
|
session.set(csrf.TOKEN_KEY, token)
|
||||||
|
return token
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Get the current token or generate a new one
|
||||||
|
token = function()
|
||||||
|
-- Get from session if exists
|
||||||
|
local token = session.get(csrf.TOKEN_KEY)
|
||||||
|
|
||||||
|
-- Generate if needed
|
||||||
|
if not token then
|
||||||
|
token = csrf.generate()
|
||||||
|
end
|
||||||
|
|
||||||
|
return token
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Generate a hidden form field with the CSRF token
|
||||||
|
field = function(field_name)
|
||||||
|
field_name = field_name or csrf.DEFAULT_FIELD
|
||||||
|
local token = csrf.token()
|
||||||
|
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Verify a given token against the session token
|
||||||
|
verify = function(token, field_name)
|
||||||
|
field_name = field_name or csrf.DEFAULT_FIELD
|
||||||
|
|
||||||
|
local env = getfenv(2)
|
||||||
|
|
||||||
|
local form = nil
|
||||||
|
if env.ctx and env.ctx.form then
|
||||||
|
form = env.ctx.form
|
||||||
|
else
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
token = token or form[field_name]
|
||||||
|
if not token then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
local session_token = session.get(csrf.TOKEN_KEY)
|
||||||
|
if not session_token then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Constant-time comparison to prevent timing attacks
|
||||||
|
-- This is safe since Lua strings are immutable
|
||||||
|
if #token ~= #session_token then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
local result = true
|
||||||
|
for i = 1, #token do
|
||||||
|
if token:sub(i, i) ~= session_token:sub(i, i) then
|
||||||
|
result = false
|
||||||
|
-- Don't break early - continue to prevent timing attacks
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
}
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- REGISTER MODULES GLOBALLY
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
-- Install modules in global scope
|
||||||
|
_G.http = http
|
||||||
|
_G.cookie = cookie
|
||||||
|
_G.session = session
|
||||||
|
_G.csrf = csrf
|
||||||
|
|
||||||
|
-- Register modules in sandbox base environment
|
||||||
|
__env_system.base_env.http = http
|
||||||
|
__env_system.base_env.cookie = cookie
|
||||||
|
__env_system.base_env.session = session
|
||||||
|
__env_system.base_env.csrf = csrf
|
|
@ -87,7 +87,6 @@ func (sm *SessionManager) GetSession(id string) *Session {
|
||||||
func (sm *SessionManager) CreateSession() *Session {
|
func (sm *SessionManager) CreateSession() *Session {
|
||||||
id := generateSessionID()
|
id := generateSessionID()
|
||||||
|
|
||||||
// Create new session
|
|
||||||
session := NewSession(id)
|
session := NewSession(id)
|
||||||
data, _ := json.Marshal(session)
|
data, _ := json.Marshal(session)
|
||||||
sm.cache.Set([]byte(id), data)
|
sm.cache.Set([]byte(id), data)
|
||||||
|
@ -95,6 +94,12 @@ func (sm *SessionManager) CreateSession() *Session {
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveSession persists a session back to the cache
|
||||||
|
func (sm *SessionManager) SaveSession(session *Session) {
|
||||||
|
data, _ := json.Marshal(session)
|
||||||
|
sm.cache.Set([]byte(session.ID), data)
|
||||||
|
}
|
||||||
|
|
||||||
// DestroySession removes a session
|
// DestroySession removes a session
|
||||||
func (sm *SessionManager) DestroySession(id string) {
|
func (sm *SessionManager) DestroySession(id string) {
|
||||||
sm.cache.Del([]byte(id))
|
sm.cache.Del([]byte(id))
|
||||||
|
|
|
@ -18,13 +18,13 @@ var (
|
||||||
|
|
||||||
// Session stores data for a single user session
|
// Session stores data for a single user session
|
||||||
type Session struct {
|
type Session struct {
|
||||||
ID string
|
ID string `json:"id"`
|
||||||
Data map[string]any
|
Data map[string]any `json:"data"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
mu sync.RWMutex // Protect concurrent access to Data
|
mu sync.RWMutex `json:"-"`
|
||||||
maxValueSize int // Maximum size of individual values in bytes
|
maxValueSize int `json:"max_value_size"`
|
||||||
totalDataSize int // Track total size of all data
|
totalDataSize int `json:"total_data_size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSession creates a new session with the given ID
|
// NewSession creates a new session with the given ID
|
||||||
|
|
Loading…
Reference in New Issue
Block a user