Moonshark/core/runner/sandbox/Sandbox.go
2025-04-09 16:19:51 -05:00

372 lines
10 KiB
Go

package sandbox
import (
"fmt"
"sync"
"github.com/goccy/go-json"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
luaCtx "Moonshark/core/runner/context"
"Moonshark/core/sessions"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// sandboxBytecode caches the compiled form of sandbox.lua so the source does
// not have to be recompiled for every Lua state. precompileSandbox stores
// whatever CompileBytecode returns here.
var sandboxBytecode []byte

// bytecodeOnce is declared alongside the cache, presumably to run
// precompileSandbox only once (its use is outside this view — confirm).
var bytecodeOnce sync.Once
// precompileSandbox compiles the embedded sandbox.lua source to bytecode using
// a throwaway Lua state and stores the output in sandboxBytecode. Failures are
// logged rather than returned.
func precompileSandbox() {
	scratch := luajit.New()
	if scratch == nil {
		logger.Error("Failed to create temporary Lua state for bytecode compilation")
		return
	}
	// Deferred calls run last-in-first-out: Cleanup runs before Close.
	defer scratch.Close()
	defer scratch.Cleanup()

	compiled, compileErr := scratch.CompileBytecode(sandboxLua, "sandbox.lua")
	sandboxBytecode = compiled
	if compileErr != nil {
		logger.Error("Failed to precompile sandbox.lua: %v", compileErr)
		return
	}
	logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(sandboxBytecode))
}
// Sandbox provides a secure execution environment for Lua scripts. It holds
// the module initializers applied by Setup and any extra global modules
// registered through AddModule.
type Sandbox struct {
	// modules maps a global name to the value exposed to scripts under it.
	// Guarded by mu.
	modules map[string]any

	// debug enables the debugLog/debugLogCont output.
	debug bool

	// mu protects modules.
	mu sync.RWMutex

	// initializers is passed to InitializeAll during Setup.
	initializers *ModuleInitializers
}
// NewSandbox returns a Sandbox with debugging off, the default module
// initializers, and an empty module registry sized for a handful of entries.
func NewSandbox() *Sandbox {
	sb := new(Sandbox)
	sb.initializers = DefaultInitializers()
	sb.modules = make(map[string]any, 8)
	return sb
}
// EnableDebug switches on verbose sandbox logging. The flag is not protected
// by mu, so set it before the sandbox is used from multiple goroutines.
func (s *Sandbox) EnableDebug() { s.debug = true }
// debugLog emits a "Sandbox "-prefixed debug message, but only when debug
// mode has been enabled via EnableDebug.
func (s *Sandbox) debugLog(format string, args ...any) {
	if !s.debug {
		return
	}
	logger.Debug("Sandbox "+format, args...)
}
// debugLogCont emits an unprefixed continuation line through
// logger.DebugCont, but only when debug mode has been enabled.
func (s *Sandbox) debugLogCont(format string, args ...any) {
	if !s.debug {
		return
	}
	logger.DebugCont(format, args...)
}
// AddModule registers module so that Setup exposes it to scripts as the
// global variable name. Registering the same name again replaces the earlier
// value.
func (s *Sandbox) AddModule(name string, module any) {
	func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.modules[name] = module
	}()
	s.debugLog("Added module: %s", name)
}
// Setup prepares state for sandboxed execution. It first runs InitializeAll
// with the sandbox's initializers, then publishes every module registered via
// AddModule as a global under its registered name. Only the state with
// stateIndex 0 produces debug output. The first error encountered is returned
// unchanged.
func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
	verbose := stateIndex == 0
	trace := func(format string, args ...any) {
		if verbose {
			s.debugLog(format, args...)
		}
	}

	trace("Setting up sandbox...")
	if initErr := InitializeAll(state, s.initializers); initErr != nil {
		trace("Failed to initialize sandbox: %v", initErr)
		return initErr
	}

	// Hold the read lock only while walking the module registry.
	registerErr := func() error {
		s.mu.RLock()
		defer s.mu.RUnlock()
		for name, module := range s.modules {
			trace("Registering module: %s", name)
			if pushErr := state.PushValue(module); pushErr != nil {
				trace("Failed to register module %s: %v", name, pushErr)
				return pushErr
			}
			state.SetGlobal(name)
		}
		return nil
	}()
	if registerErr != nil {
		return registerErr
	}

	if verbose {
		s.debugLogCont("Sandbox setup complete")
	}
	return nil
}
// Execute runs bytecode in the sandbox using a plain map of context values.
// A non-nil map is wrapped in a luaCtx.Context that has only Values set; a
// nil map results in a nil context. The work is delegated to
// OptimizedExecute.
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
	var wrapped *luaCtx.Context
	if ctx != nil {
		wrapped = &luaCtx.Context{Values: ctx}
	}
	return s.OptimizedExecute(state, bytecode, wrapped)
}
// OptimizedExecute runs precompiled bytecode inside the sandbox and, when a
// fasthttp request context is available, writes the outcome straight to it.
//
// Steps:
//  1. Load bytecode onto the Lua stack.
//  2. Reset the session globals (__session_data, __session_modified,
//     __session_id) and, if ctx carries a session, populate them from it.
//  3. Build a table from ctx.Values (or push nil when there are none) and call
//     __execute_script(script, ctxTable), keeping one result.
//  4. If the script set __session_modified to true, copy __session_data back
//     into ctx.Session and set ctx.SessionModified.
//  5. Route the result:
//     - If GetHTTPResponse reports a response, the result becomes its Body.
//       With ctx.RequestCtx set, the response (plus a session cookie when the
//       session changed) is applied and released, and (nil, nil) is returned.
//       Otherwise the response object itself is returned.
//     - Otherwise, with ctx.RequestCtx set and a non-nil result, the result is
//       written as the response body. The session is saved and its cookie set
//       if it changed, and (nil, nil) is returned.
//     - Otherwise the converted result and any conversion error are returned.
//
// An error is returned when the bytecode fails to load, a context value
// cannot be pushed, __execute_script is not a function, or the call fails. On
// the first three, everything this function pushed is popped again, so a
// reused state's stack stays balanced.
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *luaCtx.Context) (any, error) {
	if err := state.LoadBytecode(bytecode, "script"); err != nil {
		s.debugLog("Failed to load bytecode: %v", err)
		return nil, fmt.Errorf("failed to load script: %w", err)
	}
	// Stack: [script]

	// Reset session globals on every run. __session_id is cleared as well, so
	// a request without a session cannot observe an ID left behind by an
	// earlier request that ran on the same state.
	if err := state.DoString("__session_data = {}; __session_modified = false; __session_id = nil"); err != nil {
		s.debugLog("Failed to initialize session data: %v", err)
	}
	if ctx != nil && ctx.Session != nil {
		s.loadSessionGlobals(state, ctx)
	}

	var ctxValues map[string]any
	if ctx != nil {
		ctxValues = ctx.Values
	}
	if ctxValues != nil {
		state.CreateTable(0, len(ctxValues))
		for k, v := range ctxValues {
			state.PushString(k)
			if err := state.PushValue(v); err != nil {
				state.Pop(3) // key, context table and script
				s.debugLog("Failed to push context value %s: %v", k, err)
				return nil, fmt.Errorf("failed to prepare context: %w", err)
			}
			state.SetTable(-3)
		}
	} else {
		state.PushNil() // No context
	}
	// Stack: [script, ctxTable]

	state.GetGlobal("__execute_script")
	if !state.IsFunction(-1) {
		state.Pop(3) // non-function, context table and script
		s.debugLog("__execute_script is not a function")
		return nil, fmt.Errorf("sandbox execution function not found")
	}

	// Reorder [script, ctxTable, exec] into call order [exec, script, ctxTable].
	state.PushCopy(-3) // [script, ctxTable, exec, script]
	state.PushCopy(-3) // [script, ctxTable, exec, script, ctxTable]
	state.Remove(-5)   // [ctxTable, exec, script, ctxTable]
	state.Remove(-4)   // [exec, script, ctxTable]

	// Call with 2 args (script, context), 1 result.
	if err := state.Call(2, 1); err != nil {
		s.debugLog("Execution failed: %v", err)
		return nil, fmt.Errorf("script execution failed: %w", err)
	}

	result, err := state.ToValue(-1)
	state.Pop(1) // result

	if ctx != nil && ctx.Session != nil {
		s.storeSessionGlobals(state, ctx)
	}

	if httpResponse, hasResponse := GetHTTPResponse(state); hasResponse {
		httpResponse.Body = result
		if ctx != nil && ctx.SessionModified {
			httpResponse.SessionModified = true
		}
		if ctx == nil || ctx.RequestCtx == nil {
			return httpResponse, nil
		}
		if ctx.SessionModified && ctx.Session != nil {
			sessions.GlobalSessionManager.SaveSession(ctx.Session)
			// The response takes ownership of the cookie.
			httpResponse.Cookies = append(httpResponse.Cookies, buildSessionCookie(ctx.Session.ID))
		}
		ApplyHTTPResponse(httpResponse, ctx.RequestCtx)
		ReleaseResponse(httpResponse)
		return nil, nil // No need to return response object
	}

	if ctx == nil || ctx.RequestCtx == nil || result == nil {
		return result, err
	}

	writeScriptResult(ctx, result)
	if ctx.SessionModified && ctx.Session != nil {
		sessions.GlobalSessionManager.SaveSession(ctx.Session)
		cookie := buildSessionCookie(ctx.Session.ID)
		// SetCookie copies the cookie, so the pooled instance can go back.
		ctx.RequestCtx.Response.Header.SetCookie(cookie)
		fasthttp.ReleaseCookie(cookie)
	}
	return nil, nil
}

// loadSessionGlobals publishes ctx.Session to Lua. The ID goes into
// __session_id and every session value goes into the __session_data table.
// Values that cannot be pushed are logged and skipped. The stack is left as
// it was found.
func (s *Sandbox) loadSessionGlobals(state *luajit.State, ctx *luaCtx.Context) {
	// Set the ID through the stack API rather than generating Lua source: Go's
	// %q quoting is not Lua string syntax, so formatted code may not parse.
	state.PushString(ctx.Session.ID)
	state.SetGlobal("__session_id")

	state.GetGlobal("__session_data")
	defer state.Pop(1) // __session_data (or whatever non-table value was found)
	if !state.IsTable(-1) {
		return
	}
	for k, v := range ctx.Session.GetAll() {
		state.PushString(k)
		if err := state.PushValue(v); err != nil {
			// Drop the orphaned key so SetTable(-3) keeps targeting the table.
			state.Pop(1)
			s.debugLog("Failed to push session value %s: %v", k, err)
			continue
		}
		state.SetTable(-3)
	}
}

// storeSessionGlobals copies __session_data back into ctx.Session and sets
// ctx.SessionModified, but only if the script set __session_modified to true.
// Only string keys are copied. Values that fail to convert are skipped, and
// Session.Set failures are logged. The stack is left as it was found.
//
// NOTE(review): keys the script removed from __session_data are not removed
// from ctx.Session (only keys still present are Set). Confirm whether the
// session type has a delete/clear operation that should be applied here.
func (s *Sandbox) storeSessionGlobals(state *luajit.State, ctx *luaCtx.Context) {
	state.GetGlobal("__session_modified")
	modified := state.IsBoolean(-1) && state.ToBoolean(-1)
	state.Pop(1) // __session_modified
	if !modified {
		return
	}
	ctx.SessionModified = true

	state.GetGlobal("__session_data")
	defer state.Pop(1) // __session_data
	if !state.IsTable(-1) {
		return
	}
	state.PushNil() // start iteration
	for state.Next(-2) {
		// Stack: [..., data, key, value]
		if state.IsString(-2) {
			// Convert a copy of the key: lua_tolstring rewrites numeric keys in
			// place, which would break the following lua_next call.
			state.PushCopy(-2)
			key := state.ToString(-1)
			state.Pop(1)
			if value, convErr := state.ToValue(-1); convErr == nil {
				if setErr := ctx.Session.Set(key, value); setErr != nil {
					s.debugLog("Failed to set session value %s: %v", key, setErr)
				}
			}
		}
		state.Pop(1) // value; keep the key for the next iteration
	}
}

// writeScriptResult writes a script's return value as the body of
// ctx.RequestCtx, depending on its type:
//   - strings and byte slices are written verbatim;
//   - map[string]any and []any are JSON-encoded with an application/json
//     content type;
//   - anything else is formatted with %v.
//
// If JSON encoding fails, the %v form is written instead and the content type
// is left untouched.
func writeScriptResult(ctx *luaCtx.Context, result any) {
	reqCtx := ctx.RequestCtx
	switch r := result.(type) {
	case string:
		reqCtx.SetBodyString(r)
	case []byte:
		reqCtx.SetBody(r)
	case map[string]any, []any:
		buf := bytebufferpool.Get()
		defer bytebufferpool.Put(buf)
		if err := json.NewEncoder(buf).Encode(r); err != nil {
			reqCtx.SetBodyString(fmt.Sprintf("%v", r))
			return
		}
		reqCtx.Response.Header.SetContentType("application/json")
		// SetBody copies, so buf can safely be returned to the pool.
		reqCtx.SetBody(buf.Bytes())
	default:
		reqCtx.SetBodyString(fmt.Sprintf("%v", r))
	}
}

// buildSessionCookie acquires a pooled fasthttp cookie holding sessionID and
// the global session manager's cookie options. An option that is absent or
// has an unexpected type is skipped instead of causing a panic. The caller
// owns the returned cookie and must release it, or hand it to code that will.
func buildSessionCookie(sessionID string) *fasthttp.Cookie {
	opts := sessions.GlobalSessionManager.CookieOptions()
	cookie := fasthttp.AcquireCookie()
	if name, ok := opts["name"].(string); ok {
		cookie.SetKey(name)
	}
	cookie.SetValue(sessionID)
	if path, ok := opts["path"].(string); ok {
		cookie.SetPath(path)
	}
	if domain, ok := opts["domain"].(string); ok && domain != "" {
		cookie.SetDomain(domain)
	}
	if maxAge, ok := opts["max_age"].(int); ok && maxAge > 0 {
		cookie.SetMaxAge(maxAge)
	}
	secure, _ := opts["secure"].(bool) // absent means false
	cookie.SetSecure(secure)
	httpOnly, _ := opts["http_only"].(bool) // absent means false
	cookie.SetHTTPOnly(httpOnly)
	return cookie
}