This commit is contained in:
Sky Johnson 2025-04-07 21:59:11 -05:00
parent 6f020932c4
commit c0b493b6bc
19 changed files with 1347 additions and 1100 deletions

View File

@ -185,7 +185,7 @@ func (s *Moonshark) initRunner() error {
runner.WithPoolSize(s.Config.Runner.PoolSize),
runner.WithLibDirs(s.Config.Dirs.Libs...),
runner.WithSessionManager(sessionManager),
runner.WithCSRFProtection(),
http.WithCSRFProtection(),
}
// Add debug option conditionally

View File

@ -1,12 +1,119 @@
package http
import (
"Moonshark/core/runner"
"Moonshark/core/utils"
"Moonshark/core/utils/logger"
"crypto/subtle"
"github.com/valyala/fasthttp"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// ValidateCSRFToken checks if the CSRF token is valid for a request
func ValidateCSRFToken(state *luajit.State, ctx *runner.Context) bool {
// Only validate for form submissions
method, ok := ctx.Get("method").(string)
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
return true
}
// Get form data
formData, ok := ctx.Get("form").(map[string]any)
if !ok || formData == nil {
logger.Warning("CSRF validation failed: no form data")
return false
}
// Get token from form
formToken, ok := formData["csrf"].(string)
if !ok || formToken == "" {
logger.Warning("CSRF validation failed: no token in form")
return false
}
// Get session token
state.GetGlobal("session")
if state.IsNil(-1) {
state.Pop(1)
logger.Warning("CSRF validation failed: session module not available")
return false
}
state.GetField(-1, "get")
if !state.IsFunction(-1) {
state.Pop(2)
logger.Warning("CSRF validation failed: session.get not available")
return false
}
state.PushCopy(-1) // Duplicate function
state.PushString("_csrf_token")
if err := state.Call(1, 1); err != nil {
state.Pop(3) // Pop error, function and session table
logger.Warning("CSRF validation failed: %v", err)
return false
}
if state.IsNil(-1) {
state.Pop(3) // Pop nil, function and session table
logger.Warning("CSRF validation failed: no token in session")
return false
}
sessionToken := state.ToString(-1)
state.Pop(3) // Pop token, function and session table
// Constant-time comparison to prevent timing attacks
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
}
// WithCSRFProtection creates a runner option to add CSRF protection
func WithCSRFProtection() runner.RunnerOption {
return func(r *runner.Runner) {
r.AddInitHook(func(state *luajit.State, ctx *runner.Context) error {
// Get request method
method, ok := ctx.Get("method").(string)
if !ok {
return nil
}
// Only validate for form submissions
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
return nil
}
// Check for form data
form, ok := ctx.Get("form").(map[string]any)
if !ok || form == nil {
return nil
}
// Validate CSRF token
if !ValidateCSRFToken(state, ctx) {
return ErrCSRFValidationFailed
}
return nil
})
}
}
// Error for CSRF validation failure
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
// CSRFError represents a CSRF validation error
type CSRFError struct {
message string
}
// Error implements the error interface
func (e *CSRFError) Error() string {
return e.message
}
// HandleCSRFError handles a CSRF validation error
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
method := string(ctx.Method())

View File

@ -8,6 +8,7 @@ import (
"Moonshark/core/metadata"
"Moonshark/core/routers"
"Moonshark/core/runner"
"Moonshark/core/runner/sandbox"
"Moonshark/core/utils"
"Moonshark/core/utils/config"
"Moonshark/core/utils/logger"
@ -226,7 +227,7 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
// Special handling for CSRF error
if err != nil {
if csrfErr, ok := err.(*runner.CSRFError); ok {
if csrfErr, ok := err.(*CSRFError); ok {
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
HandleCSRFError(ctx, s.errorConfig)
return
@ -258,8 +259,8 @@ func writeResponse(ctx *fasthttp.RequestCtx, result any) {
}
// Check for HTTPResponse type
if httpResp, ok := result.(*runner.HTTPResponse); ok {
defer runner.ReleaseResponse(httpResp)
if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
defer sandbox.ReleaseResponse(httpResp)
// Set response headers
for name, value := range httpResp.Headers {

View File

@ -3,6 +3,8 @@ package runner
import (
"sync"
"maps"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
@ -24,7 +26,7 @@ type Context struct {
// Context pool to reduce allocations
var contextPool = sync.Pool{
New: func() interface{} {
New: func() any {
return &Context{
Values: make(map[string]any, 16), // Pre-allocate with reasonable capacity
}
@ -115,9 +117,7 @@ func (c *Context) All() map[string]any {
defer c.mu.RUnlock()
result := make(map[string]any, len(c.Values))
for k, v := range c.Values {
result[k] = v
}
maps.Copy(result, c.Values)
return result
}

View File

@ -1,187 +0,0 @@
package runner
import (
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/valyala/fasthttp"
)
// extractCookie grabs cookies from the Lua state
func extractCookie(state *luajit.State) *fasthttp.Cookie {
cookie := fasthttp.AcquireCookie()
// Get name
state.GetField(-1, "name")
if !state.IsString(-1) {
state.Pop(1)
fasthttp.ReleaseCookie(cookie)
return nil // Name is required
}
cookie.SetKey(state.ToString(-1))
state.Pop(1)
// Get value
state.GetField(-1, "value")
if state.IsString(-1) {
cookie.SetValue(state.ToString(-1))
}
state.Pop(1)
// Get path
state.GetField(-1, "path")
if state.IsString(-1) {
cookie.SetPath(state.ToString(-1))
} else {
cookie.SetPath("/") // Default path
}
state.Pop(1)
// Get domain
state.GetField(-1, "domain")
if state.IsString(-1) {
cookie.SetDomain(state.ToString(-1))
}
state.Pop(1)
// Get expires
state.GetField(-1, "expires")
if state.IsNumber(-1) {
expiry := int64(state.ToNumber(-1))
cookie.SetExpire(time.Unix(expiry, 0))
}
state.Pop(1)
// Get max age
state.GetField(-1, "max_age")
if state.IsNumber(-1) {
cookie.SetMaxAge(int(state.ToNumber(-1)))
}
state.Pop(1)
// Get secure
state.GetField(-1, "secure")
if state.IsBoolean(-1) {
cookie.SetSecure(state.ToBoolean(-1))
}
state.Pop(1)
// Get http only
state.GetField(-1, "http_only")
if state.IsBoolean(-1) {
cookie.SetHTTPOnly(state.ToBoolean(-1))
}
state.Pop(1)
return cookie
}
// LuaCookieModule provides cookie functionality to Lua scripts
const LuaCookieModule = `
-- Cookie module implementation
local cookie = {
-- Set a cookie
set = function(name, value, options, ...)
if type(name) ~= "string" then
error("cookie.set: name must be a string", 2)
end
-- Get or create response
local resp = __http_responses[1] or {}
resp.cookies = resp.cookies or {}
__http_responses[1] = resp
-- Handle options as table or legacy params
local opts = {}
if type(options) == "table" then
opts = options
elseif options ~= nil then
-- Legacy support: options is actually 'expires'
opts.expires = options
-- Check for other legacy params (4th-7th args)
local args = {...}
if args[1] then opts.path = args[1] end
if args[2] then opts.domain = args[2] end
if args[3] then opts.secure = args[3] end
if args[4] ~= nil then opts.http_only = args[4] end
end
-- Create cookie table
local cookie = {
name = name,
value = value or "",
path = opts.path or "/",
domain = opts.domain
}
-- Handle expiry
if opts.expires then
if type(opts.expires) == "number" then
if opts.expires > 0 then
cookie.max_age = opts.expires
local now = os.time()
cookie.expires = now + opts.expires
elseif opts.expires < 0 then
cookie.expires = 1
cookie.max_age = 0
else
-- opts.expires == 0: Session cookie
-- Do nothing (omitting both expires and max-age creates a session cookie)
end
end
end
-- Security flags
cookie.secure = (opts.secure ~= false)
cookie.http_only = (opts.http_only ~= false)
-- Store in cookies table
local n = #resp.cookies + 1
resp.cookies[n] = cookie
return true
end,
-- Get a cookie value
get = function(name)
if type(name) ~= "string" then
error("cookie.get: name must be a string", 2)
end
-- Access values directly from current environment
local env = getfenv(2)
-- Check if context exists and has cookies
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
return tostring(env.ctx.cookies[name])
end
return nil
end,
-- Remove a cookie
remove = function(name, path, domain)
if type(name) ~= "string" then
error("cookie.remove: name must be a string", 2)
end
-- Create an expired cookie
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
end
}
-- Install cookie module
_G.cookie = cookie
-- Make sure the cookie module is accessible in sandbox
if __env_system and __env_system.base_env then
__env_system.base_env.cookie = cookie
end
`
// CookieModuleInitFunc returns an initializer for the cookie module
func CookieModuleInitFunc() StateInitFunc {
return func(state *luajit.State) error {
return state.DoString(LuaCookieModule)
}
}

77
core/runner/Cookies.go Normal file
View File

@ -0,0 +1,77 @@
package runner
import (
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/valyala/fasthttp"
)
// extractCookie grabs cookies from the Lua state
func extractCookie(state *luajit.State) *fasthttp.Cookie {
cookie := fasthttp.AcquireCookie()
// Get name
state.GetField(-1, "name")
if !state.IsString(-1) {
state.Pop(1)
fasthttp.ReleaseCookie(cookie)
return nil // Name is required
}
cookie.SetKey(state.ToString(-1))
state.Pop(1)
// Get value
state.GetField(-1, "value")
if state.IsString(-1) {
cookie.SetValue(state.ToString(-1))
}
state.Pop(1)
// Get path
state.GetField(-1, "path")
if state.IsString(-1) {
cookie.SetPath(state.ToString(-1))
} else {
cookie.SetPath("/") // Default path
}
state.Pop(1)
// Get domain
state.GetField(-1, "domain")
if state.IsString(-1) {
cookie.SetDomain(state.ToString(-1))
}
state.Pop(1)
// Get expires
state.GetField(-1, "expires")
if state.IsNumber(-1) {
expiry := int64(state.ToNumber(-1))
cookie.SetExpire(time.Unix(expiry, 0))
}
state.Pop(1)
// Get max age
state.GetField(-1, "max_age")
if state.IsNumber(-1) {
cookie.SetMaxAge(int(state.ToNumber(-1)))
}
state.Pop(1)
// Get secure
state.GetField(-1, "secure")
if state.IsBoolean(-1) {
cookie.SetSecure(state.ToBoolean(-1))
}
state.Pop(1)
// Get http only
state.GetField(-1, "http_only")
if state.IsBoolean(-1) {
cookie.SetHTTPOnly(state.ToBoolean(-1))
}
state.Pop(1)
return cookie
}

View File

@ -1,6 +1,7 @@
package runner
import (
"Moonshark/core/runner/sandbox"
"Moonshark/core/utils/logger"
"fmt"
"strings"
@ -265,17 +266,21 @@ func init() {
GlobalRegistry.EnableDebug() // Enable debugging by default
logger.Debug("[ModuleRegistry] Registering core modules...")
GlobalRegistry.Register("util", UtilModuleInitFunc())
GlobalRegistry.Register("http", HTTPModuleInitFunc())
GlobalRegistry.RegisterWithDependencies("cookie", CookieModuleInitFunc(), []string{"http"})
GlobalRegistry.RegisterWithDependencies("csrf", CSRFModuleInitFunc(), []string{"util"})
// Register core modules - these now point to the sandbox implementations
GlobalRegistry.Register("util", func(state *luajit.State) error {
return sandbox.UtilModuleInitFunc()(state)
})
GlobalRegistry.Register("http", func(state *luajit.State) error {
return sandbox.HTTPModuleInitFunc()(state)
})
// Set explicit initialization order
GlobalRegistry.SetInitOrder([]string{
"util", // First: core utilities
"http", // Second: HTTP functionality
"cookie", // Third: Cookie functionality (uses HTTP)
"csrf", // Fourth: CSRF protection (uses go and possibly session)
"session", // Third: Session functionality
"csrf", // Fourth: CSRF protection
})
logger.DebugCont("Core modules registered successfully")

View File

@ -1,219 +0,0 @@
package runner
import (
"crypto/subtle"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// LuaCSRFModule provides CSRF protection functionality to Lua scripts
const LuaCSRFModule = `
-- CSRF protection module
local csrf = {
-- Session key where the token is stored
TOKEN_KEY = "_csrf_token",
-- Default form field name
DEFAULT_FIELD = "csrf",
-- Generate a new CSRF token and store it in the session
generate = function(length)
-- Default length is 32 characters
length = length or 32
if length < 16 then
-- Enforce minimum security
length = 16
end
-- Check if we have a session module
if not session then
error("CSRF protection requires the session module", 2)
end
local token = util.generate_token(length)
session.set(csrf.TOKEN_KEY, token)
return token
end,
-- Get the current token or generate a new one
token = function()
-- Get from session if exists
local token = session.get(csrf.TOKEN_KEY)
-- Generate if needed
if not token then
token = csrf.generate()
end
return token
end,
-- Generate a hidden form field with the CSRF token
field = function(field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local token = csrf.token()
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
end,
-- Verify a given token against the session token
verify = function(token, field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local env = getfenv(2)
local form = nil
if env.ctx and env.ctx.form then
form = env.ctx.form
else
return false
end
token = token or form[field_name]
if not token then
return false
end
local session_token = session.get(csrf.TOKEN_KEY)
if not session_token then
return false
end
-- Constant-time comparison to prevent timing attacks
-- This is safe since Lua strings are immutable
if #token ~= #session_token then
return false
end
local result = true
for i = 1, #token do
if token:sub(i, i) ~= session_token:sub(i, i) then
result = false
-- Don't break early - continue to prevent timing attacks
end
end
return result
end
}
-- Install CSRF module
_G.csrf = csrf
-- Make sure the CSRF module is accessible in sandbox
if __env_system and __env_system.base_env then
__env_system.base_env.csrf = csrf
end
`
// CSRFModuleInitFunc returns an initializer for the CSRF module
func CSRFModuleInitFunc() StateInitFunc {
return func(state *luajit.State) error {
return state.DoString(LuaCSRFModule)
}
}
// ValidateCSRFToken checks if the CSRF token is valid for a request
func ValidateCSRFToken(state *luajit.State, ctx *Context) bool {
// Only validate for form submissions
method, ok := ctx.Get("method").(string)
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
return true
}
// Get form data
formData, ok := ctx.Get("form").(map[string]any)
if !ok || formData == nil {
logger.Warning("CSRF validation failed: no form data")
return false
}
// Get token from form
formToken, ok := formData["csrf"].(string)
if !ok || formToken == "" {
logger.Warning("CSRF validation failed: no token in form")
return false
}
// Get session token
state.GetGlobal("session")
if state.IsNil(-1) {
state.Pop(1)
logger.Warning("CSRF validation failed: session module not available")
return false
}
state.GetField(-1, "get")
if !state.IsFunction(-1) {
state.Pop(2)
logger.Warning("CSRF validation failed: session.get not available")
return false
}
state.PushCopy(-1) // Duplicate function
state.PushString("_csrf_token")
if err := state.Call(1, 1); err != nil {
state.Pop(3) // Pop error, function and session table
logger.Warning("CSRF validation failed: %v", err)
return false
}
if state.IsNil(-1) {
state.Pop(3) // Pop nil, function and session table
logger.Warning("CSRF validation failed: no token in session")
return false
}
sessionToken := state.ToString(-1)
state.Pop(3) // Pop token, function and session table
// Constant-time comparison to prevent timing attacks
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
}
// WithCSRFProtection creates a runner option to add CSRF protection
func WithCSRFProtection() RunnerOption {
return func(r *Runner) {
r.AddInitHook(func(state *luajit.State, ctx *Context) error {
// Get request method
method, ok := ctx.Get("method").(string)
if !ok {
return nil
}
// Only validate for form submissions
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
return nil
}
// Check for form data
form, ok := ctx.Get("form").(map[string]any)
if !ok || form == nil {
return nil
}
// Validate CSRF token
if !ValidateCSRFToken(state, ctx) {
return ErrCSRFValidationFailed
}
return nil
})
}
}
// Error for CSRF validation failure
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
// CSRFError represents a CSRF validation error
type CSRFError struct {
message string
}
// Error implements the error interface
func (e *CSRFError) Error() string {
return e.message
}

View File

@ -12,6 +12,7 @@ import (
"github.com/panjf2000/ants/v2"
"github.com/valyala/bytebufferpool"
"Moonshark/core/runner/sandbox"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
@ -31,7 +32,7 @@ type RunnerOption func(*Runner)
// State wraps a Lua state with its sandbox
type State struct {
L *luajit.State // The Lua state
sandbox *Sandbox // Associated sandbox
sandbox *sandbox.Sandbox // Associated sandbox
index int // Index for debugging
inUse bool // Whether the state is currently in use
initTime time.Time // When this state was initialized
@ -217,7 +218,7 @@ func (r *Runner) createState(index int) (*State, error) {
}
// Create sandbox
sandbox := NewSandbox()
sandbox := sandbox.NewSandbox()
if r.debug && verbose {
sandbox.EnableDebug()
}

View File

@ -1,177 +0,0 @@
package runner
import (
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// LuaSessionModule provides session functionality to Lua scripts
const LuaSessionModule = `
-- Global table to store session data
__session_data = __session_data or {}
__session_id = __session_id or nil
__session_modified = false
-- Session module implementation
local session = {
-- Get a session value
get = function(key)
if type(key) ~= "string" then
error("session.get: key must be a string", 2)
end
if __session_data and __session_data[key] then
return __session_data[key]
end
return nil
end,
-- Set a session value
set = function(key, value)
if type(key) ~= "string" then
error("session.set: key must be a string", 2)
end
-- Ensure session data table exists
__session_data = __session_data or {}
-- Store value
__session_data[key] = value
-- Mark session as modified
__session_modified = true
return true
end,
-- Delete a session value
delete = function(key)
if type(key) ~= "string" then
error("session.delete: key must be a string", 2)
end
if __session_data then
__session_data[key] = nil
__session_modified = true
end
return true
end,
-- Clear all session data
clear = function()
__session_data = {}
__session_modified = true
return true
end,
-- Get the session ID
get_id = function()
return __session_id or nil
end,
-- Get all session data
get_all = function()
local result = {}
for k, v in pairs(__session_data or {}) do
result[k] = v
end
return result
end,
-- Check if session has a key
has = function(key)
if type(key) ~= "string" then
error("session.has: key must be a string", 2)
end
return __session_data and __session_data[key] ~= nil
end
}
-- Install session module
_G.session = session
-- Make sure the session module is accessible in sandbox
if __env_system and __env_system.base_env then
__env_system.base_env.session = session
end
-- Hook into script execution to preserve session state
local old_execute_script = __execute_script
if old_execute_script then
__execute_script = function(fn, ctx)
-- Reset modification flag at the start of request
__session_modified = false
-- Execute original function
return old_execute_script(fn, ctx)
end
end
`
// GetSessionData extracts session data from Lua state
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
// Check if session was modified
state.GetGlobal("__session_modified")
modified := state.ToBoolean(-1)
state.Pop(1)
if !modified {
return "", nil, false
}
// Get session ID
state.GetGlobal("__session_id")
sessionID := state.ToString(-1)
state.Pop(1)
// Get session data
state.GetGlobal("__session_data")
if !state.IsTable(-1) {
state.Pop(1)
return sessionID, nil, false
}
data, err := state.ToTable(-1)
state.Pop(1)
if err != nil {
logger.Error("Failed to extract session data: %v", err)
return sessionID, nil, false
}
return sessionID, data, true
}
// SetSessionData sets session data in Lua state
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
// Set session ID
state.PushString(sessionID)
state.SetGlobal("__session_id")
// Set session data
if data == nil {
data = make(map[string]any)
}
if err := state.PushTable(data); err != nil {
return err
}
state.SetGlobal("__session_data")
// Reset modification flag
state.PushBoolean(false)
state.SetGlobal("__session_modified")
return nil
}
// SessionModuleInitFunc returns an initializer for the session module
func SessionModuleInitFunc() StateInitFunc {
return func(state *luajit.State) error {
return state.DoString(LuaSessionModule)
}
}

View File

@ -3,6 +3,7 @@ package runner
import (
"github.com/valyala/fasthttp"
"Moonshark/core/runner/sandbox"
"Moonshark/core/sessions"
"Moonshark/core/utils/logger"
@ -40,9 +41,6 @@ func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
return func(r *Runner) {
handler := NewSessionHandler(manager)
// Register the session module
RegisterCoreModule("session", SessionModuleInitFunc())
// Add hooks to the runner
r.AddInitHook(handler.preRequestHook)
r.AddFinalizeHook(handler.postRequestHook)
@ -140,8 +138,10 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu
session.Set(k, v)
}
h.manager.SaveSession(session)
// Add session cookie to result if it's an HTTP response
if httpResp, ok := result.(*HTTPResponse); ok {
if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
h.addSessionCookie(httpResp, modifiedID)
}
@ -150,7 +150,7 @@ func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, resu
}
// addSessionCookie adds a session cookie to an HTTP response
func (h *SessionHandler) addSessionCookie(resp *HTTPResponse, sessionID string) {
func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) {
// Get cookie options
opts := h.manager.CookieOptions()
@ -184,3 +184,60 @@ func (h *SessionHandler) addSessionCookie(resp *HTTPResponse, sessionID string)
resp.Cookies = append(resp.Cookies, cookie)
}
// GetSessionData extracts session data from Lua state
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
// Check if session was modified
state.GetGlobal("__session_modified")
modified := state.ToBoolean(-1)
state.Pop(1)
if !modified {
return "", nil, false
}
// Get session ID
state.GetGlobal("__session_id")
sessionID := state.ToString(-1)
state.Pop(1)
// Get session data
state.GetGlobal("__session_data")
if !state.IsTable(-1) {
state.Pop(1)
return sessionID, nil, false
}
data, err := state.ToTable(-1)
state.Pop(1)
if err != nil {
logger.Error("Failed to extract session data: %v", err)
return sessionID, nil, false
}
return sessionID, data, true
}
// SetSessionData sets session data in Lua state
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
// Set session ID
state.PushString(sessionID)
state.SetGlobal("__session_id")
// Set session data
if data == nil {
data = make(map[string]any)
}
if err := state.PushTable(data); err != nil {
return err
}
state.SetGlobal("__session_data")
// Reset modification flag
state.PushBoolean(false)
state.SetGlobal("__session_modified")
return nil
}

View File

@ -0,0 +1,98 @@
package sandbox
import (
_ "embed"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
//go:embed lua/sandbox.lua
var sandboxLua string
// InitializeSandbox loads the embedded Lua sandbox code into a Lua state
func InitializeSandbox(state *luajit.State) error {
// Compile once, use many times
bytecodeOnce.Do(precompileSandbox)
if sandboxBytecode != nil {
logger.Debug("Loading sandbox.lua from precompiled bytecode")
return state.LoadAndRunBytecode(sandboxBytecode, "sandbox.lua")
}
// Fallback if compilation failed
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
return state.DoString(sandboxLua)
}
// ModuleInitializers stores initializer functions for core modules
type ModuleInitializers struct {
HTTP func(*luajit.State) error
Util func(*luajit.State) error
Session func(*luajit.State) error
Cookie func(*luajit.State) error
CSRF func(*luajit.State) error
}
// DefaultInitializers returns the default set of initializers
func DefaultInitializers() *ModuleInitializers {
return &ModuleInitializers{
HTTP: func(state *luajit.State) error {
// Register the native Go function first
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
return err
}
return nil
},
Util: func(state *luajit.State) error {
// Register util functions
return RegisterModule(state, "util", UtilModuleFunctions())
},
Session: func(state *luajit.State) error {
// Session doesn't need special initialization
return nil
},
Cookie: func(state *luajit.State) error {
// Cookie doesn't need special initialization
return nil
},
CSRF: func(state *luajit.State) error {
// CSRF doesn't need special initialization
return nil
},
}
}
// InitializeAll initializes all modules in the Lua state
func InitializeAll(state *luajit.State, initializers *ModuleInitializers) error {
// Set up dependencies first
if err := initializers.Util(state); err != nil {
return err
}
if err := initializers.HTTP(state); err != nil {
return err
}
// Load the embedded sandbox code
if err := InitializeSandbox(state); err != nil {
return err
}
// Initialize the rest of the modules
if err := initializers.Session(state); err != nil {
return err
}
if err := initializers.Cookie(state); err != nil {
return err
}
if err := initializers.CSRF(state); err != nil {
return err
}
return nil
}

View File

@ -1,4 +1,4 @@
package runner
package sandbox
import (
"context"
@ -28,7 +28,7 @@ type HTTPResponse struct {
// Response pool to reduce allocations
var responsePool = sync.Pool{
New: func() interface{} {
New: func() any {
return &HTTPResponse{
Status: 200,
Headers: make(map[string]string, 8), // Pre-allocate with reasonable capacity
@ -37,36 +37,6 @@ var responsePool = sync.Pool{
},
}
// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool
func NewHTTPResponse() *HTTPResponse {
return responsePool.Get().(*HTTPResponse)
}
// ReleaseResponse returns the response to the pool after clearing its values
func ReleaseResponse(resp *HTTPResponse) {
if resp == nil {
return
}
// Clear all values to prevent data leakage
resp.Status = 200 // Reset to default
// Clear headers
for k := range resp.Headers {
delete(resp.Headers, k)
}
// Clear cookies
resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0
// Clear body
resp.Body = nil
responsePool.Put(resp)
}
// ---------- HTTP CLIENT FUNCTIONALITY ----------
// Default HTTP client with sensible timeout
var defaultFastClient fasthttp.Client = fasthttp.Client{
MaxConnsPerHost: 1024,
@ -96,8 +66,256 @@ var DefaultHTTPClientConfig = HTTPClientConfig{
AllowRemote: true,
}
// Function name constant to ensure consistency
const httpRequestFuncName = "__http_request"
// NewHTTPResponse creates a default HTTP response, potentially reusing one from the pool
func NewHTTPResponse() *HTTPResponse {
return responsePool.Get().(*HTTPResponse)
}
// ReleaseResponse returns the response to the pool after clearing its values
func ReleaseResponse(resp *HTTPResponse) {
if resp == nil {
return
}
// Clear all values to prevent data leakage
resp.Status = 200 // Reset to default
// Clear headers
for k := range resp.Headers {
delete(resp.Headers, k)
}
// Clear cookies
resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0
// Clear body
resp.Body = nil
responsePool.Put(resp)
}
// HTTPModuleInitFunc returns an initializer function for the HTTP module
func HTTPModuleInitFunc() func(*luajit.State) error {
return func(state *luajit.State) error {
// Register the native Go function first
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
logger.Error("[HTTP Module] Failed to register __http_request function")
logger.ErrorCont("%v", err)
return err
}
// Set up default HTTP client configuration
setupHTTPClientConfig(state)
return nil
}
}
// Helper to set up HTTP client config
func setupHTTPClientConfig(state *luajit.State) {
state.NewTable()
state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
state.SetField(-2, "max_timeout")
state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second))
state.SetField(-2, "default_timeout")
state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize))
state.SetField(-2, "max_response_size")
state.PushBoolean(DefaultHTTPClientConfig.AllowRemote)
state.SetField(-2, "allow_remote")
state.SetGlobal("__http_client_config")
}
// GetHTTPResponse extracts the HTTP response from Lua state
func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
response := NewHTTPResponse()
// Get response table
state.GetGlobal("__http_responses")
if state.IsNil(-1) {
state.Pop(1)
ReleaseResponse(response) // Return unused response to pool
return nil, false
}
// Check for response at thread index
state.PushNumber(1)
state.GetTable(-2)
if state.IsNil(-1) {
state.Pop(2)
ReleaseResponse(response) // Return unused response to pool
return nil, false
}
// Get status
state.GetField(-1, "status")
if state.IsNumber(-1) {
response.Status = int(state.ToNumber(-1))
}
state.Pop(1)
// Get headers
state.GetField(-1, "headers")
if state.IsTable(-1) {
// Iterate through headers table
state.PushNil() // Start iteration
for state.Next(-2) {
// Stack has key at -2 and value at -1
if state.IsString(-2) && state.IsString(-1) {
key := state.ToString(-2)
value := state.ToString(-1)
response.Headers[key] = value
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1)
// Get cookies
state.GetField(-1, "cookies")
if state.IsTable(-1) {
// Iterate through cookies array
length := state.GetTableLength(-1)
for i := 1; i <= length; i++ {
state.PushNumber(float64(i))
state.GetTable(-2)
if state.IsTable(-1) {
cookie := extractCookie(state)
if cookie != nil {
response.Cookies = append(response.Cookies, cookie)
}
}
state.Pop(1)
}
}
state.Pop(1)
// Clean up
state.Pop(2) // Pop response table and __http_responses
return response, true
}
// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
// Set status code
ctx.SetStatusCode(httpResp.Status)
// Set headers
for name, value := range httpResp.Headers {
ctx.Response.Header.Set(name, value)
}
// Set cookies
for _, cookie := range httpResp.Cookies {
ctx.Response.Header.SetCookie(cookie)
}
// Process the body based on its type
if httpResp.Body == nil {
return
}
// Set body based on type
switch body := httpResp.Body.(type) {
case string:
ctx.SetBodyString(body)
case []byte:
ctx.SetBody(body)
case map[string]any, []any, []float64, []string, []int:
// Marshal JSON using a buffer from the pool
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
if err := json.NewEncoder(buf).Encode(body); err == nil {
// Set content type if not already set
if len(ctx.Response.Header.ContentType()) == 0 {
ctx.Response.Header.SetContentType("application/json")
}
ctx.SetBody(buf.Bytes())
} else {
// Fallback
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
default:
// Default to string representation
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
}
// extractCookie grabs cookies from the Lua state
func extractCookie(state *luajit.State) *fasthttp.Cookie {
cookie := fasthttp.AcquireCookie()
// Get name
state.GetField(-1, "name")
if !state.IsString(-1) {
state.Pop(1)
fasthttp.ReleaseCookie(cookie)
return nil // Name is required
}
cookie.SetKey(state.ToString(-1))
state.Pop(1)
// Get value
state.GetField(-1, "value")
if state.IsString(-1) {
cookie.SetValue(state.ToString(-1))
}
state.Pop(1)
// Get path
state.GetField(-1, "path")
if state.IsString(-1) {
cookie.SetPath(state.ToString(-1))
} else {
cookie.SetPath("/") // Default path
}
state.Pop(1)
// Get domain
state.GetField(-1, "domain")
if state.IsString(-1) {
cookie.SetDomain(state.ToString(-1))
}
state.Pop(1)
// Get expires
state.GetField(-1, "expires")
if state.IsNumber(-1) {
expiry := int64(state.ToNumber(-1))
cookie.SetExpire(time.Unix(expiry, 0))
}
state.Pop(1)
// Get max age
state.GetField(-1, "max_age")
if state.IsNumber(-1) {
cookie.SetMaxAge(int(state.ToNumber(-1)))
}
state.Pop(1)
// Get secure
state.GetField(-1, "secure")
if state.IsBoolean(-1) {
cookie.SetSecure(state.ToBoolean(-1))
}
state.Pop(1)
// Get http only
state.GetField(-1, "http_only")
if state.IsBoolean(-1) {
cookie.SetHTTPOnly(state.ToBoolean(-1))
}
state.Pop(1)
return cookie
}
// httpRequest makes an HTTP request and returns the result to Lua
func httpRequest(state *luajit.State) int {
@ -360,372 +578,3 @@ func httpRequest(state *luajit.State) int {
return 1
}
// HTTPModuleInitFunc returns an initializer function for the HTTP module
func HTTPModuleInitFunc() StateInitFunc {
return func(state *luajit.State) error {
// CRITICAL: Register the native Go function first
// This must be done BEFORE any Lua code that references it
if err := state.RegisterGoFunction(httpRequestFuncName, httpRequest); err != nil {
logger.Error("[HTTP Module] Failed to register __http_request function")
logger.ErrorCont("%v", err)
return err
}
// Set up default HTTP client configuration
setupHTTPClientConfig(state)
// Initialize Lua HTTP module
if err := state.DoString(LuaHTTPModule); err != nil {
logger.Error("[HTTP Module] Failed to initialize HTTP module Lua code")
logger.ErrorCont("%v", err)
return err
}
// Verify HTTP client functions are available
verifyHTTPClient(state)
return nil
}
}
// Helper to set up HTTP client config
func setupHTTPClientConfig(state *luajit.State) {
state.NewTable()
state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
state.SetField(-2, "max_timeout")
state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second))
state.SetField(-2, "default_timeout")
state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize))
state.SetField(-2, "max_response_size")
state.PushBoolean(DefaultHTTPClientConfig.AllowRemote)
state.SetField(-2, "allow_remote")
state.SetGlobal("__http_client_config")
}
// GetHTTPResponse extracts the HTTP response from Lua state
func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
response := NewHTTPResponse()
// Get response table
state.GetGlobal("__http_responses")
if state.IsNil(-1) {
state.Pop(1)
ReleaseResponse(response) // Return unused response to pool
return nil, false
}
// Check for response at thread index
state.PushNumber(1)
state.GetTable(-2)
if state.IsNil(-1) {
state.Pop(2)
ReleaseResponse(response) // Return unused response to pool
return nil, false
}
// Get status
state.GetField(-1, "status")
if state.IsNumber(-1) {
response.Status = int(state.ToNumber(-1))
}
state.Pop(1)
// Get headers
state.GetField(-1, "headers")
if state.IsTable(-1) {
// Iterate through headers table
state.PushNil() // Start iteration
for state.Next(-2) {
// Stack has key at -2 and value at -1
if state.IsString(-2) && state.IsString(-1) {
key := state.ToString(-2)
value := state.ToString(-1)
response.Headers[key] = value
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1)
// Get cookies
state.GetField(-1, "cookies")
if state.IsTable(-1) {
// Iterate through cookies array
length := state.GetTableLength(-1)
for i := 1; i <= length; i++ {
state.PushNumber(float64(i))
state.GetTable(-2)
if state.IsTable(-1) {
cookie := extractCookie(state)
if cookie != nil {
response.Cookies = append(response.Cookies, cookie)
}
}
state.Pop(1)
}
}
state.Pop(1)
// Clean up
state.Pop(2) // Pop response table and __http_responses
return response, true
}
// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
// Set status code
ctx.SetStatusCode(httpResp.Status)
// Set headers
for name, value := range httpResp.Headers {
ctx.Response.Header.Set(name, value)
}
// Set cookies
for _, cookie := range httpResp.Cookies {
ctx.Response.Header.SetCookie(cookie)
}
// Process the body based on its type
if httpResp.Body == nil {
return
}
// Set body based on type
switch body := httpResp.Body.(type) {
case string:
ctx.SetBodyString(body)
case []byte:
ctx.SetBody(body)
case map[string]any, []any, []float64, []string, []int:
// Marshal JSON using a buffer from the pool
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
if err := json.NewEncoder(buf).Encode(body); err == nil {
// Set content type if not already set
if len(ctx.Response.Header.ContentType()) == 0 {
ctx.Response.Header.SetContentType("application/json")
}
ctx.SetBody(buf.Bytes())
} else {
// Fallback
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
default:
// Default to string representation
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
}
// WithHTTPClientConfig creates a runner option to configure the HTTP client
func WithHTTPClientConfig(config HTTPClientConfig) RunnerOption {
return func(r *Runner) {
// Store the config to be applied during initialization
r.AddModule("__http_client_config", map[string]any{
"max_timeout": float64(config.MaxTimeout / time.Second),
"default_timeout": float64(config.DefaultTimeout / time.Second),
"max_response_size": float64(config.MaxResponseSize),
"allow_remote": config.AllowRemote,
})
}
}
// RestrictHTTPToLocalhost is a convenience function to restrict HTTP client
// to localhost connections only
func RestrictHTTPToLocalhost() RunnerOption {
return WithHTTPClientConfig(HTTPClientConfig{
MaxTimeout: DefaultHTTPClientConfig.MaxTimeout,
DefaultTimeout: DefaultHTTPClientConfig.DefaultTimeout,
MaxResponseSize: DefaultHTTPClientConfig.MaxResponseSize,
AllowRemote: false,
})
}
// Verify that HTTP client is properly set up
func verifyHTTPClient(state *luajit.State) {
// Get the client table
state.GetGlobal("http")
if !state.IsTable(-1) {
logger.Warning("[HTTP Module] 'http' is not a table")
state.Pop(1)
return
}
state.GetField(-1, "client")
if !state.IsTable(-1) {
logger.Warning("[HTTP Module] 'http.client' is not a table")
state.Pop(2)
return
}
// Check for get function
state.GetField(-1, "get")
if !state.IsFunction(-1) {
logger.Warning("[HTTP Module] 'http.client.get' is not a function")
}
state.Pop(1)
// Check for the request function
state.GetField(-1, "request")
if !state.IsFunction(-1) {
logger.Warning("[HTTP Module] 'http.client.request' is not a function")
}
state.Pop(3) // Pop request, client, http
}
const LuaHTTPModule = `
-- Table to store response data
__http_responses = {}
-- HTTP module implementation
local http = {
-- Set HTTP status code
set_status = function(code)
if type(code) ~= "number" then
error("http.set_status: status code must be a number", 2)
end
local resp = __http_responses[1] or {}
resp.status = code
__http_responses[1] = resp
end,
-- Set HTTP header
set_header = function(name, value)
if type(name) ~= "string" or type(value) ~= "string" then
error("http.set_header: name and value must be strings", 2)
end
local resp = __http_responses[1] or {}
resp.headers = resp.headers or {}
resp.headers[name] = value
__http_responses[1] = resp
end,
-- Set content type; set_header helper
set_content_type = function(content_type)
http.set_header("Content-Type", content_type)
end,
-- HTTP client submodule
client = {
-- Generic request function
request = function(method, url, body, options)
if type(method) ~= "string" then
error("http.client.request: method must be a string", 2)
end
if type(url) ~= "string" then
error("http.client.request: url must be a string", 2)
end
-- Call native implementation (this is the critical part)
local result = __http_request(method, url, body, options)
return result
end,
-- Simple GET request
get = function(url, options)
return http.client.request("GET", url, nil, options)
end,
-- Simple POST request with automatic content-type
post = function(url, body, options)
options = options or {}
return http.client.request("POST", url, body, options)
end,
-- Simple PUT request with automatic content-type
put = function(url, body, options)
options = options or {}
return http.client.request("PUT", url, body, options)
end,
-- Simple DELETE request
delete = function(url, options)
return http.client.request("DELETE", url, nil, options)
end,
-- Simple PATCH request
patch = function(url, body, options)
options = options or {}
return http.client.request("PATCH", url, body, options)
end,
-- Simple HEAD request
head = function(url, options)
options = options or {}
local old_options = options
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
local response = http.client.request("HEAD", url, nil, options)
return response
end,
-- Simple OPTIONS request
options = function(url, options)
return http.client.request("OPTIONS", url, nil, options)
end,
-- Shorthand function to directly get JSON
get_json = function(url, options)
options = options or {}
local response = http.client.get(url, options)
if response.ok and response.json then
return response.json
end
return nil, response
end,
-- Utility to build a URL with query parameters
build_url = function(base_url, params)
if not params or type(params) ~= "table" then
return base_url
end
local query = {}
for k, v in pairs(params) do
if type(v) == "table" then
for _, item in ipairs(v) do
table.insert(query, k .. "=" .. tostring(item))
end
else
table.insert(query, k .. "=" .. tostring(v))
end
end
if #query > 0 then
if base_url:find("?") then
return base_url .. "&" .. table.concat(query, "&")
else
return base_url .. "?" .. table.concat(query, "&")
end
end
return base_url
end
}
}
-- Install HTTP module
_G.http = http
-- Clear previous responses when executing scripts
local old_execute_script = __execute_script
if old_execute_script then
__execute_script = function(fn, ctx)
-- Clear previous response
__http_responses[1] = nil
-- Execute original function
return old_execute_script(fn, ctx)
end
end
`

View File

@ -0,0 +1,86 @@
package sandbox
import (
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// ModuleFunc is a function that returns a map of module functions
type ModuleFunc func() map[string]luajit.GoFunction
// RegisterModule registers a map of functions as a Lua module
func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
// Create a new table for the module
state.NewTable()
// Add each function to the module table
for fname, f := range funcs {
// Push function name
state.PushString(fname)
// Push function
if err := state.PushGoFunction(f); err != nil {
state.Pop(1) // Pop table
return err
}
// Set table[fname] = f
state.SetTable(-3)
}
// Register the module globally
state.SetGlobal(name)
return nil
}
// CombineInitFuncs combines multiple state initializer functions into one
func CombineInitFuncs(funcs ...func(*luajit.State) error) func(*luajit.State) error {
return func(state *luajit.State) error {
for _, f := range funcs {
if f != nil {
if err := f(state); err != nil {
return err
}
}
}
return nil
}
}
// ModuleInitFunc creates a state initializer that registers multiple modules
func ModuleInitFunc(modules map[string]ModuleFunc) func(*luajit.State) error {
return func(state *luajit.State) error {
for name, moduleFunc := range modules {
if err := RegisterModule(state, name, moduleFunc()); err != nil {
logger.Error("Failed to register module %s: %v", name, err)
return err
}
}
return nil
}
}
// RegisterLuaCode registers a Lua code snippet as a module
func RegisterLuaCode(state *luajit.State, code string) error {
return state.DoString(code)
}
// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code
func RegisterLuaCodeInitFunc(code string) func(*luajit.State) error {
return func(state *luajit.State) error {
return RegisterLuaCode(state, code)
}
}
// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module
func RegisterLuaModuleInitFunc(name string, code string) func(*luajit.State) error {
return func(state *luajit.State) error {
// Create name = {} global
state.NewTable()
state.SetGlobal(name)
// Then run the module code which will populate it
return state.DoString(code)
}
}

View File

@ -1,4 +1,4 @@
package runner
package sandbox
import (
"fmt"
@ -6,17 +6,44 @@ import (
"github.com/goccy/go-json"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Global bytecode cache to improve performance
var (
sandboxBytecode []byte
bytecodeOnce sync.Once
)
// precompileSandbox compiles the sandbox.lua code to bytecode once
func precompileSandbox() {
tempState := luajit.New()
if tempState == nil {
logger.Error("Failed to create temporary Lua state for bytecode compilation")
return
}
defer tempState.Close()
defer tempState.Cleanup()
var err error
sandboxBytecode, err = tempState.CompileBytecode(sandboxLua, "sandbox.lua")
if err != nil {
logger.Error("Failed to precompile sandbox.lua: %v", err)
} else {
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(sandboxBytecode))
}
}
// Sandbox provides a secure execution environment for Lua scripts
type Sandbox struct {
modules map[string]any // Custom modules for environment
debug bool // Enable debug output
mu sync.RWMutex // Protects modules
initializers *ModuleInitializers // Module initializers
}
// NewSandbox creates a new sandbox environment
@ -24,6 +51,7 @@ func NewSandbox() *Sandbox {
return &Sandbox{
modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity
debug: false,
initializers: DefaultInitializers(),
}
}
@ -39,7 +67,7 @@ func (s *Sandbox) debugLog(format string, args ...interface{}) {
}
}
// debugLog logs a message if debug mode is enabled
// debugLogCont logs a continuation message if debug mode is enabled
func (s *Sandbox) debugLogCont(format string, args ...interface{}) {
if s.debug {
logger.DebugCont(format, args...)
@ -60,19 +88,27 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
verbose := stateIndex == 0
if verbose {
s.debugLog("is setting up...")
s.debugLog("Setting up sandbox...")
}
// Register modules in the global environment
// Initialize modules with the embedded sandbox code
if err := InitializeAll(state, s.initializers); err != nil {
if verbose {
s.debugLog("Failed to initialize sandbox: %v", err)
}
return err
}
// Register custom modules in the global environment
s.mu.RLock()
for name, module := range s.modules {
if verbose {
s.debugLog("is registering module: %s", name)
s.debugLog("Registering module: %s", name)
}
if err := state.PushValue(module); err != nil {
s.mu.RUnlock()
if verbose {
s.debugLog("failed to register module %s: %v", name, err)
s.debugLog("Failed to register module %s: %v", name, err)
}
return err
}
@ -80,60 +116,8 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
}
s.mu.RUnlock()
// Initialize environment setup
err := state.DoString(`
-- Global tables for response handling
__http_responses = __http_responses or {}
-- Create environment inheriting from _G
function __create_env(ctx)
-- Create environment with metatable inheriting from _G
local env = setmetatable({}, {__index = _G})
-- Add context if provided
if ctx then
env.ctx = ctx
end
-- Add proper require function to this environment
if __setup_require then
__setup_require(env)
end
return env
end
-- Execute script with clean environment
function __execute_script(fn, ctx)
-- Clear previous responses
__http_responses[1] = nil
-- Create environment
local env = __create_env(ctx)
-- Set environment for function
setfenv(fn, env)
-- Execute with protected call
local ok, result = pcall(fn)
if not ok then
error(result, 0)
end
return result
end
`)
if err != nil {
if verbose {
s.debugLog("failed to set up...")
s.debugLogCont("%v", err)
}
return err
}
if verbose {
s.debugLogCont("Complete")
s.debugLogCont("Sandbox setup complete")
}
return nil
}
@ -152,6 +136,14 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]a
return s.OptimizedExecute(state, bytecode, nil)
}
// Context represents execution context for a Lua script
type Context struct {
// Values stores any context values (route params, HTTP request info, etc.)
Values map[string]any
// RequestCtx for HTTP requests
RequestCtx *fasthttp.RequestCtx
}
// OptimizedExecute runs bytecode with a fasthttp context if available
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Context) (any, error) {
// Use a buffer from the pool for any string operations

View File

@ -1,4 +1,4 @@
package runner
package sandbox
import (
"crypto/rand"
@ -9,6 +9,20 @@ import (
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// UtilModuleInitFunc returns an initializer for the util module
func UtilModuleInitFunc() func(*luajit.State) error {
return func(state *luajit.State) error {
return RegisterModule(state, "util", UtilModuleFunctions())
}
}
// UtilModuleFunctions returns all functions for the util module
func UtilModuleFunctions() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"generate_token": GenerateToken,
}
}
// GenerateToken creates a cryptographically secure random token
func GenerateToken(s *luajit.State) int {
// Get the length from the Lua arguments (default to 32)
@ -42,17 +56,3 @@ func GenerateToken(s *luajit.State) int {
s.PushString(token)
return 1 // One return value
}
// UtilModuleFunctions returns all functions for the go module
func UtilModuleFunctions() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"generate_token": GenerateToken,
}
}
// UtilModuleInitFunc returns an initializer for the go module
func UtilModuleInitFunc() StateInitFunc {
return func(state *luajit.State) error {
return RegisterModule(state, "util", UtilModuleFunctions())
}
}

View File

@ -0,0 +1,552 @@
--[[
Moonshark Lua Sandbox Environment
This file contains all the Lua code needed for the sandbox environment,
including core modules and utilities. It's designed to be embedded in the
Go binary at build time.
]]--
-- Global tables for execution context
__http_responses = {}
__module_paths = {}
__module_bytecode = {}
__ready_modules = {}
__session_data = {}
__session_id = nil
__session_modified = false
__env_system = {
base_env = {}
}
-- ======================================================================
-- CORE SANDBOX FUNCTIONALITY
-- ======================================================================
-- Create environment inheriting from _G
function __create_env(ctx)
-- Create environment with metatable inheriting from _G
local env = setmetatable({}, {__index = _G})
-- Add context if provided
if ctx then
env.ctx = ctx
end
-- Add proper require function to this environment
if __setup_require then
__setup_require(env)
end
return env
end
-- Execute script with clean environment
function __execute_script(fn, ctx)
-- Clear previous responses
__http_responses[1] = nil
-- Reset session modification flag
__session_modified = false
-- Create environment
local env = __create_env(ctx)
-- Set environment for function
setfenv(fn, env)
-- Execute with protected call
local ok, result = pcall(fn)
if not ok then
error(result, 0)
end
return result
end
-- ======================================================================
-- MODULE LOADING SYSTEM
-- ======================================================================
-- Setup environment-aware require function
function __setup_require(env)
-- Create require function specific to this environment
env.require = function(modname)
-- Check if already loaded
if package.loaded[modname] then
return package.loaded[modname]
end
-- Check preloaded modules
if __ready_modules[modname] then
local loader = package.preload[modname]
if loader then
-- Set environment for loader
setfenv(loader, env)
-- Execute and store result
local result = loader()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
end
-- Direct file load as fallback
if __module_paths[modname] then
local path = __module_paths[modname]
local chunk, err = loadfile(path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
end
-- Full path search as last resort
local errors = {}
for path in package.path:gmatch("[^;]+") do
local file_path = path:gsub("?", modname:gsub("%.", "/"))
local chunk, err = loadfile(file_path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
table.insert(errors, "\tno file '" .. file_path .. "'")
end
error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2)
end
return env
end
-- ======================================================================
-- HTTP MODULE
-- ======================================================================
-- HTTP module implementation
local http = {
-- Set HTTP status code
set_status = function(code)
if type(code) ~= "number" then
error("http.set_status: status code must be a number", 2)
end
local resp = __http_responses[1] or {}
resp.status = code
__http_responses[1] = resp
end,
-- Set HTTP header
set_header = function(name, value)
if type(name) ~= "string" or type(value) ~= "string" then
error("http.set_header: name and value must be strings", 2)
end
local resp = __http_responses[1] or {}
resp.headers = resp.headers or {}
resp.headers[name] = value
__http_responses[1] = resp
end,
-- Set content type; set_header helper
set_content_type = function(content_type)
http.set_header("Content-Type", content_type)
end,
-- HTTP client submodule
client = {
-- Generic request function
request = function(method, url, body, options)
if type(method) ~= "string" then
error("http.client.request: method must be a string", 2)
end
if type(url) ~= "string" then
error("http.client.request: url must be a string", 2)
end
-- Call native implementation
local result = __http_request(method, url, body, options)
return result
end,
-- Simple GET request
get = function(url, options)
return http.client.request("GET", url, nil, options)
end,
-- Simple POST request with automatic content-type
post = function(url, body, options)
options = options or {}
return http.client.request("POST", url, body, options)
end,
-- Simple PUT request with automatic content-type
put = function(url, body, options)
options = options or {}
return http.client.request("PUT", url, body, options)
end,
-- Simple DELETE request
delete = function(url, options)
return http.client.request("DELETE", url, nil, options)
end,
-- Simple PATCH request
patch = function(url, body, options)
options = options or {}
return http.client.request("PATCH", url, body, options)
end,
-- Simple HEAD request
head = function(url, options)
options = options or {}
local old_options = options
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
local response = http.client.request("HEAD", url, nil, options)
return response
end,
-- Simple OPTIONS request
options = function(url, options)
return http.client.request("OPTIONS", url, nil, options)
end,
-- Shorthand function to directly get JSON
get_json = function(url, options)
options = options or {}
local response = http.client.get(url, options)
if response.ok and response.json then
return response.json
end
return nil, response
end,
-- Utility to build a URL with query parameters
build_url = function(base_url, params)
if not params or type(params) ~= "table" then
return base_url
end
local query = {}
for k, v in pairs(params) do
if type(v) == "table" then
for _, item in ipairs(v) do
table.insert(query, k .. "=" .. tostring(item))
end
else
table.insert(query, k .. "=" .. tostring(v))
end
end
if #query > 0 then
if base_url:find("?") then
return base_url .. "&" .. table.concat(query, "&")
else
return base_url .. "?" .. table.concat(query, "&")
end
end
return base_url
end
}
}
-- ======================================================================
-- COOKIE MODULE
-- ======================================================================
-- Cookie module implementation
local cookie = {
-- Set a cookie
set = function(name, value, options, ...)
if type(name) ~= "string" then
error("cookie.set: name must be a string", 2)
end
-- Get or create response
local resp = __http_responses[1] or {}
resp.cookies = resp.cookies or {}
__http_responses[1] = resp
-- Handle options as table or legacy params
local opts = {}
if type(options) == "table" then
opts = options
elseif options ~= nil then
-- Legacy support: options is actually 'expires'
opts.expires = options
-- Check for other legacy params (4th-7th args)
local args = {...}
if args[1] then opts.path = args[1] end
if args[2] then opts.domain = args[2] end
if args[3] then opts.secure = args[3] end
if args[4] ~= nil then opts.http_only = args[4] end
end
-- Create cookie table
local cookie = {
name = name,
value = value or "",
path = opts.path or "/",
domain = opts.domain
}
-- Handle expiry
if opts.expires then
if type(opts.expires) == "number" then
if opts.expires > 0 then
cookie.max_age = opts.expires
local now = os.time()
cookie.expires = now + opts.expires
elseif opts.expires < 0 then
cookie.expires = 1
cookie.max_age = 0
else
-- opts.expires == 0: Session cookie
-- Do nothing (omitting both expires and max-age creates a session cookie)
end
end
end
-- Security flags
cookie.secure = (opts.secure ~= false)
cookie.http_only = (opts.http_only ~= false)
-- Store in cookies table
local n = #resp.cookies + 1
resp.cookies[n] = cookie
return true
end,
-- Get a cookie value
get = function(name)
if type(name) ~= "string" then
error("cookie.get: name must be a string", 2)
end
-- Access values directly from current environment
local env = getfenv(2)
-- Check if context exists and has cookies
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
return tostring(env.ctx.cookies[name])
end
return nil
end,
-- Remove a cookie
remove = function(name, path, domain)
if type(name) ~= "string" then
error("cookie.remove: name must be a string", 2)
end
-- Create an expired cookie
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain})
end
}
-- ======================================================================
-- SESSION MODULE
-- ======================================================================
-- Session module implementation
local session = {
-- Get a session value
get = function(key)
if type(key) ~= "string" then
error("session.get: key must be a string", 2)
end
if __session_data and __session_data[key] then
return __session_data[key]
end
return nil
end,
-- Set a session value
set = function(key, value)
if type(key) ~= "string" then
error("session.set: key must be a string", 2)
end
-- Ensure session data table exists
__session_data = __session_data or {}
-- Store value
__session_data[key] = value
-- Mark session as modified
__session_modified = true
return true
end,
-- Delete a session value
delete = function(key)
if type(key) ~= "string" then
error("session.delete: key must be a string", 2)
end
if __session_data then
__session_data[key] = nil
__session_modified = true
end
return true
end,
-- Clear all session data
clear = function()
__session_data = {}
__session_modified = true
return true
end,
-- Get the session ID
get_id = function()
return __session_id or nil
end,
-- Get all session data
get_all = function()
local result = {}
for k, v in pairs(__session_data or {}) do
result[k] = v
end
return result
end,
-- Check if session has a key
has = function(key)
if type(key) ~= "string" then
error("session.has: key must be a string", 2)
end
return __session_data and __session_data[key] ~= nil
end
}
-- ======================================================================
-- CSRF MODULE
-- ======================================================================
-- CSRF protection module
local csrf = {
-- Session key where the token is stored
TOKEN_KEY = "_csrf_token",
-- Default form field name
DEFAULT_FIELD = "csrf",
-- Generate a new CSRF token and store it in the session
generate = function(length)
-- Default length is 32 characters
length = length or 32
if length < 16 then
-- Enforce minimum security
length = 16
end
-- Check if we have a session module
if not session then
error("CSRF protection requires the session module", 2)
end
local token = util.generate_token(length)
session.set(csrf.TOKEN_KEY, token)
return token
end,
-- Get the current token or generate a new one
token = function()
-- Get from session if exists
local token = session.get(csrf.TOKEN_KEY)
-- Generate if needed
if not token then
token = csrf.generate()
end
return token
end,
-- Generate a hidden form field with the CSRF token
field = function(field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local token = csrf.token()
return string.format('<input type="hidden" name="%s" value="%s">', field_name, token)
end,
-- Verify a given token against the session token
verify = function(token, field_name)
field_name = field_name or csrf.DEFAULT_FIELD
local env = getfenv(2)
local form = nil
if env.ctx and env.ctx.form then
form = env.ctx.form
else
return false
end
token = token or form[field_name]
if not token then
return false
end
local session_token = session.get(csrf.TOKEN_KEY)
if not session_token then
return false
end
-- Constant-time comparison to prevent timing attacks
-- This is safe since Lua strings are immutable
if #token ~= #session_token then
return false
end
local result = true
for i = 1, #token do
if token:sub(i, i) ~= session_token:sub(i, i) then
result = false
-- Don't break early - continue to prevent timing attacks
end
end
return result
end
}
-- ======================================================================
-- REGISTER MODULES GLOBALLY
-- ======================================================================
-- Install modules in global scope
_G.http = http
_G.cookie = cookie
_G.session = session
_G.csrf = csrf
-- Register modules in sandbox base environment
__env_system.base_env.http = http
__env_system.base_env.cookie = cookie
__env_system.base_env.session = session
__env_system.base_env.csrf = csrf

View File

@ -87,7 +87,6 @@ func (sm *SessionManager) GetSession(id string) *Session {
func (sm *SessionManager) CreateSession() *Session {
id := generateSessionID()
// Create new session
session := NewSession(id)
data, _ := json.Marshal(session)
sm.cache.Set([]byte(id), data)
@ -95,6 +94,12 @@ func (sm *SessionManager) CreateSession() *Session {
return session
}
// SaveSession persists a session back to the cache
func (sm *SessionManager) SaveSession(session *Session) {
data, _ := json.Marshal(session)
sm.cache.Set([]byte(session.ID), data)
}
// DestroySession removes a session
func (sm *SessionManager) DestroySession(id string) {
sm.cache.Del([]byte(id))

View File

@ -18,13 +18,13 @@ var (
// Session stores data for a single user session
type Session struct {
ID string
Data map[string]any
CreatedAt time.Time
UpdatedAt time.Time
mu sync.RWMutex // Protect concurrent access to Data
maxValueSize int // Maximum size of individual values in bytes
totalDataSize int // Track total size of all data
ID string `json:"id"`
Data map[string]any `json:"data"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
mu sync.RWMutex `json:"-"`
maxValueSize int `json:"max_value_size"`
totalDataSize int `json:"total_data_size"`
}
// NewSession creates a new session with the given ID