massive rewrite 1

This commit is contained in:
Sky Johnson 2025-04-09 19:03:35 -05:00
parent f4b1e5fad7
commit 5ebcd97662
23 changed files with 1446 additions and 2683 deletions

View File

@ -184,8 +184,6 @@ func (s *Moonshark) initRunner() error {
runnerOpts := []runner.RunnerOption{ runnerOpts := []runner.RunnerOption{
runner.WithPoolSize(s.Config.Runner.PoolSize), runner.WithPoolSize(s.Config.Runner.PoolSize),
runner.WithLibDirs(s.Config.Dirs.Libs...), runner.WithLibDirs(s.Config.Dirs.Libs...),
runner.WithSessionManager(sessionManager),
http.WithCSRFProtection(),
} }
// Add debug option conditionally // Add debug option conditionally

View File

@ -2,18 +2,19 @@ package http
import ( import (
"Moonshark/core/runner" "Moonshark/core/runner"
luaCtx "Moonshark/core/runner/context"
"Moonshark/core/utils" "Moonshark/core/utils"
"Moonshark/core/utils/logger" "Moonshark/core/utils/logger"
"crypto/subtle" "crypto/subtle"
"errors"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
) )
// Error for CSRF validation failure
var ErrCSRFValidationFailed = errors.New("CSRF token validation failed")
// ValidateCSRFToken checks if the CSRF token is valid for a request // ValidateCSRFToken checks if the CSRF token is valid for a request
func ValidateCSRFToken(state *luajit.State, ctx *luaCtx.Context) bool { func ValidateCSRFToken(ctx *runner.Context) bool {
// Only validate for form submissions // Only validate for form submissions
method, ok := ctx.Get("method").(string) method, ok := ctx.Get("method").(string)
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") { if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
@ -34,87 +35,23 @@ func ValidateCSRFToken(state *luajit.State, ctx *luaCtx.Context) bool {
return false return false
} }
// Get session token // Get token from session
state.GetGlobal("session") sessionData := ctx.SessionData
if state.IsNil(-1) { if sessionData == nil {
state.Pop(1) logger.Warning("CSRF validation failed: no session data")
logger.Warning("CSRF validation failed: session module not available")
return false return false
} }
state.GetField(-1, "get") sessionToken, ok := sessionData["_csrf_token"].(string)
if !state.IsFunction(-1) { if !ok || sessionToken == "" {
state.Pop(2)
logger.Warning("CSRF validation failed: session.get not available")
return false
}
state.PushCopy(-1) // Duplicate function
state.PushString("_csrf_token")
if err := state.Call(1, 1); err != nil {
state.Pop(3) // Pop error, function and session table
logger.Warning("CSRF validation failed: %v", err)
return false
}
if state.IsNil(-1) {
state.Pop(3) // Pop nil, function and session table
logger.Warning("CSRF validation failed: no token in session") logger.Warning("CSRF validation failed: no token in session")
return false return false
} }
sessionToken := state.ToString(-1)
state.Pop(3) // Pop token, function and session table
// Constant-time comparison to prevent timing attacks // Constant-time comparison to prevent timing attacks
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1 return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
} }
// WithCSRFProtection creates a runner option to add CSRF protection
func WithCSRFProtection() runner.RunnerOption {
return func(r *runner.Runner) {
r.AddInitHook(func(state *luajit.State, ctx *luaCtx.Context) error {
// Get request method
method, ok := ctx.Get("method").(string)
if !ok {
return nil
}
// Only validate for form submissions
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
return nil
}
// Check for form data
form, ok := ctx.Get("form").(map[string]any)
if !ok || form == nil {
return nil
}
// Validate CSRF token
if !ValidateCSRFToken(state, ctx) {
return ErrCSRFValidationFailed
}
return nil
})
}
}
// Error for CSRF validation failure
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
// CSRFError represents a CSRF validation error
type CSRFError struct {
message string
}
// Error implements the error interface
func (e *CSRFError) Error() string {
return e.message
}
// HandleCSRFError handles a CSRF validation error // HandleCSRFError handles a CSRF validation error
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) { func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
method := string(ctx.Method()) method := string(ctx.Method())
@ -129,3 +66,39 @@ func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig
errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg) errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg)
ctx.SetBody([]byte(errorHTML)) ctx.SetBody([]byte(errorHTML))
} }
// GenerateCSRFToken creates a new CSRF token and stores it in the session
func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) {
if length < 16 {
length = 16 // Minimum token length for security
}
// Create secure random token
token, err := GenerateSecureToken(length)
if err != nil {
return "", err
}
// Store token in session
ctx.SessionData["_csrf_token"] = token
return token, nil
}
// GetCSRFToken retrieves the current CSRF token or generates a new one
func GetCSRFToken(ctx *runner.Context) (string, error) {
// Check if token already exists in session
if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" {
return token, nil
}
// Generate new token
return GenerateCSRFToken(ctx, 32)
}
// CSRFMiddleware validates CSRF tokens for state-changing requests
func CSRFMiddleware(ctx *runner.Context) error {
if !ValidateCSRFToken(ctx) {
return ErrCSRFValidationFailed
}
return nil
}

View File

@ -1,116 +0,0 @@
package http
import (
"errors"
"mime/multipart"
"strings"
"github.com/valyala/fasthttp"
)
// Maximum form parse size (16MB)
const maxFormSize = 16 << 20
// Common errors
var (
ErrFormSizeTooLarge = errors.New("form size too large")
ErrInvalidFormType = errors.New("invalid form content type")
)
// ParseForm parses a POST request body into a map of values
// Supports both application/x-www-form-urlencoded and multipart/form-data content types
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
// Only handle POST, PUT, PATCH
method := string(ctx.Method())
if method != "POST" && method != "PUT" && method != "PATCH" {
return make(map[string]any), nil
}
// Check content type
contentType := string(ctx.Request.Header.ContentType())
if contentType == "" {
return make(map[string]any), nil
}
result := make(map[string]any)
// Check for content length to prevent DOS
if len(ctx.Request.Body()) > maxFormSize {
return nil, ErrFormSizeTooLarge
}
// Handle by content type
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
return parseURLEncodedForm(ctx)
} else if strings.HasPrefix(contentType, "multipart/form-data") {
return parseMultipartForm(ctx)
}
// Unrecognized content type
return result, nil
}
// parseURLEncodedForm handles application/x-www-form-urlencoded forms
func parseURLEncodedForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
result := make(map[string]any)
// Process form values directly from PostArgs()
ctx.PostArgs().VisitAll(func(key, value []byte) {
keyStr := string(key)
valStr := string(value)
// Check if we already have this key
if existing, ok := result[keyStr]; ok {
// If it's already a slice, append
if existingSlice, ok := existing.([]string); ok {
result[keyStr] = append(existingSlice, valStr)
} else {
// Convert to slice and append
result[keyStr] = []string{existing.(string), valStr}
}
} else {
// New key
result[keyStr] = valStr
}
})
return result, nil
}
// parseMultipartForm handles multipart/form-data forms
func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
result := make(map[string]any)
// Parse multipart form
form, err := ctx.MultipartForm()
if err != nil {
if err == multipart.ErrMessageTooLarge || strings.Contains(err.Error(), "too large") {
return nil, ErrFormSizeTooLarge
}
return nil, err
}
// Process form values
for key, values := range form.Value {
if len(values) == 1 {
// Single value
result[key] = values[0]
} else if len(values) > 1 {
// Multiple values - store as string slice
strValues := make([]string, len(values))
copy(strValues, values)
result[key] = strValues
}
}
// We don't handle file uploads here - could be extended in the future
// if needed to support file uploads to Lua
return result, nil
}
// Usage:
// After parsing the form with ParseForm, you can add it to the context with:
// ctx.Set("form", formData)
//
// This makes the form data accessible in Lua as ctx.form.field_name

View File

@ -1,43 +0,0 @@
package http
import (
"time"
"Moonshark/core/utils/logger"
)
// StatusColors for different status code ranges
const (
colorGreen = "\033[32m" // 2xx - Success
colorCyan = "\033[36m" // 3xx - Redirection
colorYellow = "\033[33m" // 4xx - Client Errors
colorRed = "\033[31m" // 5xx - Server Errors
colorReset = "\033[0m" // Reset color
colorGray = "\033[90m"
)
// LogRequest logs an HTTP request with custom formatting
func LogRequest(statusCode int, method, path string, duration time.Duration) {
statusColor := getStatusColor(statusCode)
// Use the logger's raw message writer to bypass the standard format
logger.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s",
colorGray, time.Now().Format(logger.TimeFormat()), colorReset,
statusColor, statusCode, method, colorReset, path, colorGray, duration, colorReset)
}
// getStatusColor returns the ANSI color code for a status code
func getStatusColor(code int) string {
switch {
case code >= 200 && code < 300:
return colorGreen
case code >= 300 && code < 400:
return colorCyan
case code >= 400 && code < 500:
return colorYellow
case code >= 500:
return colorRed
default:
return ""
}
}

View File

@ -1,43 +0,0 @@
package http
import (
"github.com/valyala/fasthttp"
)
// QueryToLua converts HTTP query parameters to a map that can be used with LuaJIT.
// Single value parameters are stored as strings.
// Multi-value parameters are converted to []any arrays.
func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
result := make(map[string]any)
// Use a map to track keys that have multiple values
multiValueKeys := make(map[string]bool)
// Process all query args
ctx.QueryArgs().VisitAll(func(key, value []byte) {
keyStr := string(key)
valStr := string(value)
if _, exists := result[keyStr]; exists {
// This key already exists, convert to array if not already
if !multiValueKeys[keyStr] {
// First duplicate, convert existing value to array
multiValueKeys[keyStr] = true
result[keyStr] = []any{result[keyStr], valStr}
} else {
// Already an array, append
result[keyStr] = append(result[keyStr].([]any), valStr)
}
} else {
// New key
result[keyStr] = valStr
}
})
// If we don't have any query parameters, return empty map
if len(result) == 0 {
return make(map[string]any)
}
return result
}

View File

@ -2,21 +2,17 @@ package http
import ( import (
"context" "context"
"fmt" "errors"
"strings"
"time" "time"
"Moonshark/core/metadata" "Moonshark/core/metadata"
"Moonshark/core/routers" "Moonshark/core/routers"
"Moonshark/core/runner" "Moonshark/core/runner"
luaCtx "Moonshark/core/runner/context"
"Moonshark/core/runner/sandbox"
"Moonshark/core/sessions" "Moonshark/core/sessions"
"Moonshark/core/utils" "Moonshark/core/utils"
"Moonshark/core/utils/config" "Moonshark/core/utils/config"
"Moonshark/core/utils/logger" "Moonshark/core/utils/logger"
"github.com/goccy/go-json"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -29,12 +25,14 @@ type Server struct {
loggingEnabled bool loggingEnabled bool
debugMode bool debugMode bool
config *config.Config config *config.Config
sessionManager *sessions.SessionManager
errorConfig utils.ErrorPageConfig errorConfig utils.ErrorPageConfig
} }
// New creates a new HTTP server with optimized connection settings // New creates a new HTTP server with optimized connection settings
func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runner *runner.Runner, func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter,
loggingEnabled bool, debugMode bool, overrideDir string, config *config.Config) *Server { runner *runner.Runner, loggingEnabled bool, debugMode bool,
overrideDir string, config *config.Config) *Server {
server := &Server{ server := &Server{
luaRouter: luaRouter, luaRouter: luaRouter,
@ -43,6 +41,7 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
loggingEnabled: loggingEnabled, loggingEnabled: loggingEnabled,
debugMode: debugMode, debugMode: debugMode,
config: config, config: config,
sessionManager: sessions.GlobalSessionManager,
errorConfig: utils.ErrorPageConfig{ errorConfig: utils.ErrorPageConfig{
OverrideDir: overrideDir, OverrideDir: overrideDir,
DebugMode: debugMode, DebugMode: debugMode,
@ -55,7 +54,7 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
Name: "Moonshark/" + metadata.Version, Name: "Moonshark/" + metadata.Version,
ReadTimeout: 30 * time.Second, ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
MaxRequestBodySize: 16 << 20, // 16MB - consistent with Forms.go MaxRequestBodySize: 16 << 20, // 16MB
DisableKeepalive: false, DisableKeepalive: false,
TCPKeepalive: true, TCPKeepalive: true,
TCPKeepalivePeriod: 60 * time.Second, TCPKeepalivePeriod: 60 * time.Second,
@ -99,7 +98,7 @@ func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
// Process the request // Process the request
s.processRequest(ctx) s.processRequest(ctx)
// Log the request with our custom format // Log the request
if s.loggingEnabled { if s.loggingEnabled {
duration := time.Since(start) duration := time.Since(start)
LogRequest(ctx.Response.StatusCode(), method, path, duration) LogRequest(ctx.Response.StatusCode(), method, path, duration)
@ -153,48 +152,25 @@ func (s *Server) processRequest(ctx *fasthttp.RequestCtx) {
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path))) ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
} }
// HandleMethodNotAllowed responds with a 405 Method Not Allowed error
func HandleMethodNotAllowed(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
path := string(ctx.Path())
ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
ctx.SetBody([]byte(utils.MethodNotAllowedPage(errorConfig, path)))
}
// handleLuaRoute executes a Lua route // handleLuaRoute executes a Lua route
// Updated handleLuaRoute function to handle sessions
func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params) { func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params) {
luaCtx := luaCtx.NewHTTPContext(ctx) // Create context for Lua execution
luaCtx := runner.NewHTTPContext(ctx)
defer luaCtx.Release() defer luaCtx.Release()
method := string(ctx.Method()) method := string(ctx.Method())
path := string(ctx.Path()) path := string(ctx.Path())
host := string(ctx.Host()) host := string(ctx.Host())
// Set up context // Set up additional context values
luaCtx.Set("method", method) luaCtx.Set("method", method)
luaCtx.Set("path", path) luaCtx.Set("path", path)
luaCtx.Set("host", host) luaCtx.Set("host", host)
// Headers // Initialize session
headerMap := make(map[string]any) session := s.sessionManager.GetSessionFromRequest(ctx)
ctx.Request.Header.VisitAll(func(key, value []byte) { luaCtx.SessionID = session.ID
headerMap[string(key)] = string(value) luaCtx.SessionData = session.GetAll()
})
luaCtx.Set("headers", headerMap)
// Cookies
cookieMap := make(map[string]any)
ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
cookieMap[string(key)] = string(value)
})
if len(cookieMap) > 0 {
luaCtx.Set("cookies", cookieMap)
luaCtx.Set("_request_cookies", cookieMap) // For backward compatibility
} else {
luaCtx.Set("cookies", make(map[string]any))
luaCtx.Set("_request_cookies", make(map[string]any))
}
// URL parameters // URL parameters
if params.Count > 0 { if params.Count > 0 {
@ -207,11 +183,7 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
luaCtx.Set("params", make(map[string]any)) luaCtx.Set("params", make(map[string]any))
} }
// Query parameters // Parse form data for POST/PUT/PATCH requests
queryMap := QueryToLua(ctx)
luaCtx.Set("query", queryMap)
// Form data
if method == "POST" || method == "PUT" || method == "PATCH" { if method == "POST" || method == "PUT" || method == "PATCH" {
formData, err := ParseForm(ctx) formData, err := ParseForm(ctx)
if err == nil && len(formData) > 0 { if err == nil && len(formData) > 0 {
@ -226,40 +198,26 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
luaCtx.Set("form", make(map[string]any)) luaCtx.Set("form", make(map[string]any))
} }
// Session handling // CSRF middleware for state-changing requests
cookieOpts := sessions.GlobalSessionManager.CookieOptions() if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
cookieName := cookieOpts["name"].(string) if !ValidateCSRFToken(luaCtx) {
sessionCookie := ctx.Request.Header.Cookie(cookieName) HandleCSRFError(ctx, s.errorConfig)
return
var sessionID string }
if sessionCookie != nil {
sessionID = string(sessionCookie)
} }
// Get or create session
var session *sessions.Session
if sessionID != "" {
session = sessions.GlobalSessionManager.GetSession(sessionID)
} else {
session = sessions.GlobalSessionManager.CreateSession()
}
// Set session in context
luaCtx.Session = session
// Execute Lua script // Execute Lua script
result, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath) response, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
// Special handling for CSRF error
if err != nil { if err != nil {
if csrfErr, ok := err.(*CSRFError); ok { logger.Error("Error executing Lua route: %v", err)
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
// Special handling for specific errors
if errors.Is(err, ErrCSRFValidationFailed) {
HandleCSRFError(ctx, s.errorConfig) HandleCSRFError(ctx, s.errorConfig)
return return
} }
// Normal error handling // General error handling
logger.Error("Error executing Lua route: %v", err)
ctx.SetContentType("text/html; charset=utf-8") ctx.SetContentType("text/html; charset=utf-8")
ctx.SetStatusCode(fasthttp.StatusInternalServerError) ctx.SetStatusCode(fasthttp.StatusInternalServerError)
errorHTML := utils.InternalErrorPage(s.errorConfig, path, err.Error()) errorHTML := utils.InternalErrorPage(s.errorConfig, path, err.Error())
@ -267,129 +225,21 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
return return
} }
// Handle session updates if needed // Save session if modified
if luaCtx.SessionModified { if response.SessionModified {
sessions.GlobalSessionManager.SaveSession(luaCtx.Session) // Update session data
for k, v := range response.SessionData {
// Set session cookie session.Set(k, v)
cookie := fasthttp.AcquireCookie()
cookie.SetKey(cookieName)
cookie.SetValue(luaCtx.Session.ID)
cookie.SetPath(cookieOpts["path"].(string))
if domain, ok := cookieOpts["domain"].(string); ok && domain != "" {
cookie.SetDomain(domain)
} }
s.sessionManager.SaveSession(session)
if maxAge, ok := cookieOpts["max_age"].(int); ok { s.sessionManager.ApplySessionCookie(ctx, session)
cookie.SetMaxAge(maxAge)
}
cookie.SetSecure(cookieOpts["secure"].(bool))
cookie.SetHTTPOnly(cookieOpts["http_only"].(bool))
ctx.Response.Header.SetCookie(cookie)
fasthttp.ReleaseCookie(cookie)
} }
// If we got a non-nil result, write it to the response // Apply response to HTTP context
if result != nil { runner.ApplyResponse(response, ctx)
writeResponse(ctx, result)
}
}
// Content types for responses // Release the response when done
const ( runner.ReleaseResponse(response)
contentTypeJSON = "application/json"
contentTypePlain = "text/plain"
)
// writeResponse writes the Lua result to the HTTP response
func writeResponse(ctx *fasthttp.RequestCtx, result any) {
if result == nil {
ctx.SetStatusCode(fasthttp.StatusNoContent)
return
}
// First check the raw type of the result for strong type identification
// Sometimes type assertions don't work as expected with interface values
resultType := fmt.Sprintf("%T", result)
// Strong check for HTTP response
if strings.Contains(resultType, "HTTPResponse") || strings.Contains(resultType, "sandbox.HTTPResponse") {
httpResp, ok := result.(*sandbox.HTTPResponse)
if ok {
defer sandbox.ReleaseResponse(httpResp)
// Set response headers
for name, value := range httpResp.Headers {
ctx.Response.Header.Set(name, value)
}
// Set cookies
for _, cookie := range httpResp.Cookies {
ctx.Response.Header.SetCookie(cookie)
}
// Set status code
ctx.SetStatusCode(httpResp.Status)
// Process the body based on its type
if httpResp.Body == nil {
return
}
// Continue with the body only
result = httpResp.Body
} else {
// We identified it as HTTPResponse but couldn't convert it
// This is a programming error
logger.Error("Found HTTPResponse type but failed to convert: %v", resultType)
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}
}
// Check if it's a map (table) or array - return as JSON
isJSON := false
switch result.(type) {
case map[string]any, []any, []float64, []string, []int:
isJSON = true
}
if isJSON {
setContentTypeIfMissing(ctx, contentTypeJSON)
data, err := json.Marshal(result)
if err != nil {
logger.Error("Failed to marshal response: %v", err)
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}
ctx.SetBody(data)
return
}
// Handle string and byte slice cases directly
switch r := result.(type) {
case string:
setContentTypeIfMissing(ctx, contentTypePlain)
ctx.SetBodyString(r)
return
case []byte:
setContentTypeIfMissing(ctx, contentTypePlain)
ctx.SetBody(r)
return
}
// If we reach here, it's an unexpected type - convert to string as a last resort
setContentTypeIfMissing(ctx, contentTypePlain)
ctx.SetBodyString(fmt.Sprintf("%v", result))
}
func setContentTypeIfMissing(ctx *fasthttp.RequestCtx, contentType string) {
if len(ctx.Response.Header.ContentType()) == 0 {
ctx.SetContentType(contentType)
}
} }
// handleDebugStats displays debug statistics // handleDebugStats displays debug statistics
@ -399,12 +249,14 @@ func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) {
// Add component stats // Add component stats
routeCount, bytecodeBytes := s.luaRouter.GetRouteStats() routeCount, bytecodeBytes := s.luaRouter.GetRouteStats()
moduleCount := s.luaRunner.GetModuleCount() //stateCount := s.luaRunner.GetStateCount()
//activeStates := s.luaRunner.GetActiveStateCount()
stats.Components = utils.ComponentStats{ stats.Components = utils.ComponentStats{
RouteCount: routeCount, RouteCount: routeCount,
BytecodeBytes: bytecodeBytes, BytecodeBytes: bytecodeBytes,
ModuleCount: moduleCount, //StatesCount: stateCount,
//ActiveStates: activeStates,
} }
// Generate HTML page // Generate HTML page

206
core/http/Utils.go Normal file
View File

@ -0,0 +1,206 @@
package http
import (
"crypto/rand"
"encoding/base64"
"fmt"
"mime/multipart"
"strings"
"time"
"Moonshark/core/utils/logger"
"github.com/valyala/fasthttp"
)
// LogRequest logs an HTTP request with its status code and duration
func LogRequest(statusCode int, method, path string, duration time.Duration) {
var statusColor, resetColor, methodColor string
// Status code colors
if statusCode >= 200 && statusCode < 300 {
statusColor = "\u001b[32m" // Green for 2xx
} else if statusCode >= 300 && statusCode < 400 {
statusColor = "\u001b[36m" // Cyan for 3xx
} else if statusCode >= 400 && statusCode < 500 {
statusColor = "\u001b[33m" // Yellow for 4xx
} else {
statusColor = "\u001b[31m" // Red for 5xx and others
}
// Method colors
switch method {
case "GET":
methodColor = "\u001b[32m" // Green
case "POST":
methodColor = "\u001b[34m" // Blue
case "PUT":
methodColor = "\u001b[33m" // Yellow
case "DELETE":
methodColor = "\u001b[31m" // Red
default:
methodColor = "\u001b[35m" // Magenta for others
}
resetColor = "\u001b[0m"
// Format duration
var durationStr string
if duration.Milliseconds() < 1 {
durationStr = fmt.Sprintf("%.2fµs", float64(duration.Microseconds()))
} else if duration.Milliseconds() < 1000 {
durationStr = fmt.Sprintf("%.2fms", float64(duration.Microseconds())/1000)
} else {
durationStr = fmt.Sprintf("%.2fs", duration.Seconds())
}
// Log with colors
logger.Server("%s%d%s %s%s%s %s %s",
statusColor, statusCode, resetColor,
methodColor, method, resetColor,
path, durationStr)
}
// QueryToLua converts HTTP query args to a Lua-friendly map
func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
queryMap := make(map[string]any)
// Visit all query parameters
ctx.QueryArgs().VisitAll(func(key, value []byte) {
// Convert to string
k := string(key)
v := string(value)
// Check if this key already exists as an array
if existing, ok := queryMap[k]; ok {
// If it's already an array, append to it
if arr, ok := existing.([]string); ok {
queryMap[k] = append(arr, v)
} else if str, ok := existing.(string); ok {
// Convert existing string to array and append new value
queryMap[k] = []string{str, v}
}
} else {
// New key, store as string
queryMap[k] = v
}
})
return queryMap
}
// ParseForm extracts form data from a request
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
formData := make(map[string]any)
// Check if multipart form
if strings.Contains(string(ctx.Request.Header.ContentType()), "multipart/form-data") {
return parseMultipartForm(ctx)
}
// Regular form
ctx.PostArgs().VisitAll(func(key, value []byte) {
k := string(key)
v := string(value)
// Check if this key already exists
if existing, ok := formData[k]; ok {
// If it's already an array, append to it
if arr, ok := existing.([]string); ok {
formData[k] = append(arr, v)
} else if str, ok := existing.(string); ok {
// Convert existing string to array and append new value
formData[k] = []string{str, v}
}
} else {
// New key, store as string
formData[k] = v
}
})
return formData, nil
}
// parseMultipartForm handles multipart/form-data requests
func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
formData := make(map[string]any)
// Parse multipart form
form, err := ctx.MultipartForm()
if err != nil {
return nil, err
}
// Process form values
for key, values := range form.Value {
if len(values) == 1 {
formData[key] = values[0]
} else if len(values) > 1 {
formData[key] = values
}
}
// Process files (store file info, not the content)
if len(form.File) > 0 {
files := make(map[string]any)
for fieldName, fileHeaders := range form.File {
if len(fileHeaders) == 1 {
files[fieldName] = fileInfoToMap(fileHeaders[0])
} else if len(fileHeaders) > 1 {
fileInfos := make([]map[string]any, 0, len(fileHeaders))
for _, fh := range fileHeaders {
fileInfos = append(fileInfos, fileInfoToMap(fh))
}
files[fieldName] = fileInfos
}
}
formData["_files"] = files
}
return formData, nil
}
// fileInfoToMap converts a FileHeader to a map for Lua
func fileInfoToMap(fh *multipart.FileHeader) map[string]any {
return map[string]any{
"filename": fh.Filename,
"size": fh.Size,
"mimetype": getMimeType(fh),
}
}
// getMimeType gets the mime type from a file header
func getMimeType(fh *multipart.FileHeader) string {
if fh.Header != nil {
contentType := fh.Header.Get("Content-Type")
if contentType != "" {
return contentType
}
}
// Fallback to basic type detection from filename
if strings.HasSuffix(fh.Filename, ".pdf") {
return "application/pdf"
} else if strings.HasSuffix(fh.Filename, ".png") {
return "image/png"
} else if strings.HasSuffix(fh.Filename, ".jpg") || strings.HasSuffix(fh.Filename, ".jpeg") {
return "image/jpeg"
} else if strings.HasSuffix(fh.Filename, ".gif") {
return "image/gif"
} else if strings.HasSuffix(fh.Filename, ".svg") {
return "image/svg+xml"
}
return "application/octet-stream"
}
// GenerateSecureToken creates a cryptographically secure random token
func GenerateSecureToken(length int) (string, error) {
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b)[:length], nil
}

View File

@ -3,8 +3,6 @@ package runner
import ( import (
"sync" "sync"
"Moonshark/core/sessions"
"github.com/valyala/bytebufferpool" "github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -17,9 +15,9 @@ type Context struct {
// FastHTTP context if this was created from an HTTP request // FastHTTP context if this was created from an HTTP request
RequestCtx *fasthttp.RequestCtx RequestCtx *fasthttp.RequestCtx
// Session data and management // Session information
Session *sessions.Session SessionID string
SessionModified bool SessionData map[string]any
// Buffer for efficient string operations // Buffer for efficient string operations
buffer *bytebufferpool.ByteBuffer buffer *bytebufferpool.ByteBuffer
@ -29,7 +27,8 @@ type Context struct {
var contextPool = sync.Pool{ var contextPool = sync.Pool{
New: func() any { New: func() any {
return &Context{ return &Context{
Values: make(map[string]any, 16), Values: make(map[string]any, 16),
SessionData: make(map[string]any, 8),
} }
}, },
} }
@ -43,6 +42,44 @@ func NewContext() *Context {
func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context { func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context {
ctx := NewContext() ctx := NewContext()
ctx.RequestCtx = requestCtx ctx.RequestCtx = requestCtx
// Extract common HTTP values that Lua might need
if requestCtx != nil {
ctx.Values["_request_method"] = string(requestCtx.Method())
ctx.Values["_request_path"] = string(requestCtx.Path())
ctx.Values["_request_url"] = string(requestCtx.RequestURI())
// Extract cookies
cookies := make(map[string]any)
requestCtx.Request.Header.VisitAllCookie(func(key, value []byte) {
cookies[string(key)] = string(value)
})
ctx.Values["_request_cookies"] = cookies
// Extract query params
query := make(map[string]any)
requestCtx.QueryArgs().VisitAll(func(key, value []byte) {
query[string(key)] = string(value)
})
ctx.Values["_request_query"] = query
// Extract form data if present
if requestCtx.IsPost() || requestCtx.IsPut() {
form := make(map[string]any)
requestCtx.PostArgs().VisitAll(func(key, value []byte) {
form[string(key)] = string(value)
})
ctx.Values["_request_form"] = form
}
// Extract headers
headers := make(map[string]any)
requestCtx.Request.Header.VisitAll(func(key, value []byte) {
headers[string(key)] = string(value)
})
ctx.Values["_request_headers"] = headers
}
return ctx return ctx
} }
@ -53,9 +90,12 @@ func (c *Context) Release() {
delete(c.Values, k) delete(c.Values, k)
} }
for k := range c.SessionData {
delete(c.SessionData, k)
}
// Reset session info // Reset session info
c.Session = nil c.SessionID = ""
c.SessionModified = false
// Reset request context // Reset request context
c.RequestCtx = nil c.RequestCtx = nil
@ -87,13 +127,12 @@ func (c *Context) Get(key string) any {
return c.Values[key] return c.Values[key]
} }
// Contains checks if a key exists in the context // SetSession sets a session data value
func (c *Context) Contains(key string) bool { func (c *Context) SetSession(key string, value any) {
_, exists := c.Values[key] c.SessionData[key] = value
return exists
} }
// Delete removes a value from the context // GetSession retrieves a session data value
func (c *Context) Delete(key string) { func (c *Context) GetSession(key string) any {
delete(c.Values, key) return c.SessionData[key]
} }

View File

@ -1,262 +0,0 @@
package runner
import (
"Moonshark/core/runner/sandbox"
"Moonshark/core/utils/logger"
"fmt"
"strings"
"sync"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// CoreModuleRegistry manages the initialization and reloading of core modules
type CoreModuleRegistry struct {
modules map[string]sandbox.StateInitFunc // Module initializers
initOrder []string // Explicit initialization order
dependencies map[string][]string // Module dependencies
initializedFlag map[string]bool // Track which modules are initialized
mu sync.RWMutex
debug bool
}
// NewCoreModuleRegistry creates a new core module registry
func NewCoreModuleRegistry() *CoreModuleRegistry {
return &CoreModuleRegistry{
modules: make(map[string]sandbox.StateInitFunc),
initOrder: []string{},
dependencies: make(map[string][]string),
initializedFlag: make(map[string]bool),
debug: false,
}
}
// EnableDebug turns on debug logging
func (r *CoreModuleRegistry) EnableDebug() {
r.debug = true
}
// debugLog prints debug messages if enabled
func (r *CoreModuleRegistry) debugLog(format string, args ...interface{}) {
if r.debug {
logger.Debug("CoreRegistry "+format, args...)
}
}
// Register adds a module to the registry
func (r *CoreModuleRegistry) Register(name string, initFunc sandbox.StateInitFunc) {
r.mu.Lock()
defer r.mu.Unlock()
r.modules[name] = initFunc
// Add to initialization order if not already there
for _, n := range r.initOrder {
if n == name {
return // Already registered
}
}
r.initOrder = append(r.initOrder, name)
r.debugLog("registered module %s", name)
}
// RegisterWithDependencies registers a module with explicit dependencies
func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) {
r.mu.Lock()
defer r.mu.Unlock()
r.modules[name] = initFunc
r.dependencies[name] = dependencies
// Add to initialization order if not already there
for _, n := range r.initOrder {
if n == name {
return // Already registered
}
}
r.initOrder = append(r.initOrder, name)
r.debugLog("registered module %s with dependencies: %v", name, dependencies)
}
// SetInitOrder sets explicit initialization order
func (r *CoreModuleRegistry) SetInitOrder(order []string) {
r.mu.Lock()
defer r.mu.Unlock()
// Create new init order
newOrder := make([]string, 0, len(order))
// First add all known modules that are in the specified order
for _, name := range order {
if _, exists := r.modules[name]; exists && !contains(newOrder, name) {
newOrder = append(newOrder, name)
}
}
// Then add any modules not in the specified order
for name := range r.modules {
if !contains(newOrder, name) {
newOrder = append(newOrder, name)
}
}
r.initOrder = newOrder
r.debugLog("Set initialization order: %v", r.initOrder)
}
// Initialize initializes all registered modules
func (r *CoreModuleRegistry) Initialize(state *luajit.State, stateIndex int) error {
r.mu.RLock()
defer r.mu.RUnlock()
verbose := stateIndex == 0
if verbose {
r.debugLog("initializing %d modules...", len(r.initOrder))
}
// Clear initialization flags
r.initializedFlag = make(map[string]bool)
// Initialize modules in order, respecting dependencies
for _, name := range r.initOrder {
if err := r.initializeModule(state, name, []string{}, verbose); err != nil {
return err
}
}
if verbose {
r.debugLog("All modules initialized successfully")
}
return nil
}
// initializeModule initializes a module and its dependencies
func (r *CoreModuleRegistry) initializeModule(state *luajit.State, name string,
initStack []string, verbose bool) error {
// Check if already initialized
if r.initializedFlag[name] {
return nil
}
// Check for circular dependencies
for _, n := range initStack {
if n == name {
return fmt.Errorf("circular dependency detected: %s -> %s",
strings.Join(initStack, " -> "), name)
}
}
// Get init function
initFunc, ok := r.modules[name]
if !ok {
return fmt.Errorf("module not found: %s", name)
}
// Initialize dependencies first
deps := r.dependencies[name]
if len(deps) > 0 {
newStack := append(initStack, name)
for _, dep := range deps {
if err := r.initializeModule(state, dep, newStack, verbose); err != nil {
return err
}
}
}
err := initFunc(state)
if err != nil {
// Always log failures regardless of verbose setting
r.debugLog("Initializing module %s... failure: %v", name, err)
return fmt.Errorf("failed to initialize module %s: %w", name, err)
}
r.initializedFlag[name] = true
if verbose {
r.debugLog("Initializing module %s... success", name)
}
return nil
}
// InitializeModule initializes a specific module
func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string) error {
r.mu.RLock()
defer r.mu.RUnlock()
// Clear initialization flag for this module
r.initializedFlag[name] = false
// Always use verbose logging for explicit module initialization
return r.initializeModule(state, name, []string{}, true)
}
// MatchModuleName checks if a file path corresponds to a registered module
func (r *CoreModuleRegistry) MatchModuleName(modName string) (string, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
// Exact match
if _, ok := r.modules[modName]; ok {
return modName, true
}
// Check if the module name ends with a registered module
for name := range r.modules {
if strings.HasSuffix(modName, "."+name) {
return name, true
}
}
return "", false
}
// Global registry instance
var GlobalRegistry = NewCoreModuleRegistry()
// Initialize global registry with core modules
func init() {
GlobalRegistry.EnableDebug() // Enable debugging by default
logger.Debug("[ModuleRegistry] Registering core modules...")
// Register core modules
GlobalRegistry.Register("util", func(state *luajit.State) error {
return sandbox.UtilModuleInitFunc()(state)
})
GlobalRegistry.Register("http", func(state *luajit.State) error {
return sandbox.HTTPModuleInitFunc()(state)
})
// Set explicit initialization order
GlobalRegistry.SetInitOrder([]string{
"util", // First: core utilities
"http", // Second: HTTP functionality
"session", // Third: Session functionality
"csrf", // Fourth: CSRF protection
})
logger.Debug("Core modules registered successfully")
}
// RegisterCoreModule registers a core module with the global registry
func RegisterCoreModule(name string, initFunc sandbox.StateInitFunc) {
GlobalRegistry.Register(name, initFunc)
}
// RegisterCoreModuleWithDependencies registers a module with dependencies
func RegisterCoreModuleWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) {
GlobalRegistry.RegisterWithDependencies(name, initFunc, dependencies)
}
// Helper functions
func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}

61
core/runner/Embed.go Normal file
View File

@ -0,0 +1,61 @@
package runner
import (
_ "embed"
"sync"
"sync/atomic"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
//go:embed sandbox.lua
var sandboxLuaCode string
// Global bytecode cache to improve performance
var (
sandboxBytecode atomic.Pointer[[]byte]
bytecodeOnce sync.Once
)
// precompileSandboxCode compiles the sandbox.lua code to bytecode once
func precompileSandboxCode() {
// Create temporary state for compilation
tempState := luajit.New()
if tempState == nil {
logger.Error("Failed to create temp Lua state for bytecode compilation")
return
}
defer tempState.Close()
defer tempState.Cleanup()
code, err := tempState.CompileBytecode(sandboxLuaCode, "sandbox.lua")
if err != nil {
logger.Error("Failed to compile sandbox code: %v", err)
return
}
bytecode := make([]byte, len(code))
copy(bytecode, code)
sandboxBytecode.Store(&bytecode)
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code))
}
// loadSandboxIntoState loads the sandbox code into a Lua state
func loadSandboxIntoState(state *luajit.State) error {
// Initialize bytecode once
bytecodeOnce.Do(precompileSandboxCode)
// Use precompiled bytecode if available
bytecode := sandboxBytecode.Load()
if bytecode != nil && len(*bytecode) > 0 {
logger.Debug("Loading sandbox.lua from precompiled bytecode")
return state.LoadAndRunBytecode(*bytecode, "sandbox.lua")
}
// Fallback to direct execution
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
return state.DoString(sandboxLuaCode)
}

334
core/runner/Http.go Normal file
View File

@ -0,0 +1,334 @@
package runner
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/url"
"strings"
"time"
"github.com/goccy/go-json"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Default HTTP client with sensible timeout
var defaultFastClient = fasthttp.Client{
MaxConnsPerHost: 1024,
MaxIdleConnDuration: time.Minute,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
DisableHeaderNamesNormalizing: true,
}
// HTTPClientConfig contains client settings
type HTTPClientConfig struct {
MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit)
DefaultTimeout time.Duration // Default request timeout
MaxResponseSize int64 // Maximum response size in bytes (0 = no limit)
AllowRemote bool // Whether to allow remote connections
}
// DefaultHTTPClientConfig provides sensible defaults
var DefaultHTTPClientConfig = HTTPClientConfig{
MaxTimeout: 60 * time.Second,
DefaultTimeout: 30 * time.Second,
MaxResponseSize: 10 * 1024 * 1024, // 10MB
AllowRemote: true,
}
// ApplyResponse applies a Response to a fasthttp.RequestCtx
func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
// Set status code
ctx.SetStatusCode(resp.Status)
// Set headers
for name, value := range resp.Headers {
ctx.Response.Header.Set(name, value)
}
// Set cookies
for _, cookie := range resp.Cookies {
ctx.Response.Header.SetCookie(cookie)
}
// Process the body based on its type
if resp.Body == nil {
return
}
// Get a buffer from the pool
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
// Set body based on type
switch body := resp.Body.(type) {
case string:
ctx.SetBodyString(body)
case []byte:
ctx.SetBody(body)
case map[string]any, []any, []float64, []string, []int:
// Marshal JSON
if err := json.NewEncoder(buf).Encode(body); err == nil {
// Set content type if not already set
if len(ctx.Response.Header.ContentType()) == 0 {
ctx.Response.Header.SetContentType("application/json")
}
ctx.SetBody(buf.Bytes())
} else {
// Fallback
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
default:
// Default to string representation
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
}
// httpRequest makes an HTTP request and returns the result to Lua
func httpRequest(state *luajit.State) int {
// Get method (required)
if !state.IsString(1) {
state.PushString("http.client.request: method must be a string")
return -1
}
method := strings.ToUpper(state.ToString(1))
// Get URL (required)
if !state.IsString(2) {
state.PushString("http.client.request: url must be a string")
return -1
}
urlStr := state.ToString(2)
// Parse URL to check if it's valid
parsedURL, err := url.Parse(urlStr)
if err != nil {
state.PushString("Invalid URL: " + err.Error())
return -1
}
// Get client configuration
config := DefaultHTTPClientConfig
// Check if remote connections are allowed
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
state.PushString("Remote connections are not allowed")
return -1
}
// Use bytebufferpool for request and response
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// Set up request
req.Header.SetMethod(method)
req.SetRequestURI(urlStr)
req.Header.Set("User-Agent", "Moonshark/1.0")
// Get body (optional)
if state.GetTop() >= 3 && !state.IsNil(3) {
if state.IsString(3) {
// String body
req.SetBodyString(state.ToString(3))
} else if state.IsTable(3) {
// Table body - convert to JSON
luaTable, err := state.ToTable(3)
if err != nil {
state.PushString("Failed to parse body table: " + err.Error())
return -1
}
// Use bytebufferpool for JSON serialization
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
state.PushString("Failed to convert body to JSON: " + err.Error())
return -1
}
req.SetBody(buf.Bytes())
req.Header.SetContentType("application/json")
} else {
state.PushString("Body must be a string or table")
return -1
}
}
// Process options (headers, timeout, etc.)
timeout := config.DefaultTimeout
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
// Process headers
state.GetField(4, "headers")
if state.IsTable(-1) {
// Iterate through headers
state.PushNil() // Start iteration
for state.Next(-2) {
// Stack now has key at -2 and value at -1
if state.IsString(-2) && state.IsString(-1) {
headerName := state.ToString(-2)
headerValue := state.ToString(-1)
req.Header.Set(headerName, headerValue)
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1) // Pop headers table
// Get timeout
state.GetField(4, "timeout")
if state.IsNumber(-1) {
requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second
// Apply max timeout if configured
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
timeout = config.MaxTimeout
} else {
timeout = requestTimeout
}
}
state.Pop(1) // Pop timeout
// Process query parameters
state.GetField(4, "query")
if state.IsTable(-1) {
// Create URL args
args := req.URI().QueryArgs()
// Iterate through query params
state.PushNil() // Start iteration
for state.Next(-2) {
if state.IsString(-2) {
paramName := state.ToString(-2)
// Handle different value types
if state.IsString(-1) {
args.Add(paramName, state.ToString(-1))
} else if state.IsNumber(-1) {
args.Add(paramName, strings.TrimRight(strings.TrimRight(
state.ToString(-1), "0"), "."))
} else if state.IsBoolean(-1) {
if state.ToBoolean(-1) {
args.Add(paramName, "true")
} else {
args.Add(paramName, "false")
}
}
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1) // Pop query table
}
// Create context with timeout
_, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// Execute request
err = defaultFastClient.DoTimeout(req, resp, timeout)
if err != nil {
errStr := "Request failed: " + err.Error()
if errors.Is(err, fasthttp.ErrTimeout) {
errStr = "Request timed out after " + timeout.String()
}
state.PushString(errStr)
return -1
}
// Create response table
state.NewTable()
// Set status code
state.PushNumber(float64(resp.StatusCode()))
state.SetField(-2, "status")
// Set status text
statusText := fasthttp.StatusMessage(resp.StatusCode())
state.PushString(statusText)
state.SetField(-2, "status_text")
// Set body
var respBody []byte
// Apply size limits to response
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
// Make a limited copy
respBody = make([]byte, config.MaxResponseSize)
copy(respBody, resp.Body())
} else {
respBody = resp.Body()
}
state.PushString(string(respBody))
state.SetField(-2, "body")
// Parse body as JSON if content type is application/json
contentType := string(resp.Header.ContentType())
if strings.Contains(contentType, "application/json") {
var jsonData any
if err := json.Unmarshal(respBody, &jsonData); err == nil {
if err := state.PushValue(jsonData); err == nil {
state.SetField(-2, "json")
}
}
}
// Set headers
state.NewTable()
resp.Header.VisitAll(func(key, value []byte) {
state.PushString(string(value))
state.SetField(-2, string(key))
})
state.SetField(-2, "headers")
// Create ok field (true if status code is 2xx)
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
state.SetField(-2, "ok")
return 1
}
// generateToken creates a cryptographically secure random token
func generateToken(state *luajit.State) int {
// Get the length from the Lua arguments (default to 32)
length := 32
if state.GetTop() >= 1 && state.IsNumber(1) {
length = int(state.ToNumber(1))
}
// Enforce minimum length for security
if length < 16 {
length = 16
}
// Generate secure random bytes
tokenBytes := make([]byte, length)
if _, err := rand.Read(tokenBytes); err != nil {
logger.Error("Failed to generate secure token: %v", err)
state.PushString("")
return 1 // Return empty string on error
}
// Encode as base64
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
// Trim to requested length (base64 might be longer)
if len(token) > length {
token = token[:length]
}
// Push the token to the Lua stack
state.PushString(token)
return 1 // One return value
}

View File

@ -6,6 +6,8 @@ import (
"strings" "strings"
"sync" "sync"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
) )
@ -15,61 +17,15 @@ type ModuleConfig struct {
LibDirs []string // Additional library directories LibDirs []string // Additional library directories
} }
// ModuleInfo stores information about a loaded module
type ModuleInfo struct {
Name string
Path string
IsCore bool
Bytecode []byte
}
// ModuleLoader manages module loading and caching // ModuleLoader manages module loading and caching
type ModuleLoader struct { type ModuleLoader struct {
config *ModuleConfig config *ModuleConfig
registry *ModuleRegistry
pathCache map[string]string // Cache module paths for fast lookups pathCache map[string]string // Cache module paths for fast lookups
bytecodeCache map[string][]byte // Cache of compiled bytecode bytecodeCache map[string][]byte // Cache of compiled bytecode
debug bool debug bool
mu sync.RWMutex mu sync.RWMutex
} }
// ModuleRegistry keeps track of Lua modules for file watching
type ModuleRegistry struct {
// Maps file paths to module names
pathToModule sync.Map
// Maps module names to file paths
moduleToPath sync.Map
}
// NewModuleRegistry creates a new module registry
func NewModuleRegistry() *ModuleRegistry {
return &ModuleRegistry{}
}
// Register adds a module path to the registry
func (r *ModuleRegistry) Register(path string, name string) {
r.pathToModule.Store(path, name)
r.moduleToPath.Store(name, path)
}
// GetModuleName retrieves a module name by path
func (r *ModuleRegistry) GetModuleName(path string) (string, bool) {
value, ok := r.pathToModule.Load(path)
if !ok {
return "", false
}
return value.(string), true
}
// GetModulePath retrieves a path by module name
func (r *ModuleRegistry) GetModulePath(name string) (string, bool) {
value, ok := r.moduleToPath.Load(name)
if !ok {
return "", false
}
return value.(string), true
}
// NewModuleLoader creates a new module loader // NewModuleLoader creates a new module loader
func NewModuleLoader(config *ModuleConfig) *ModuleLoader { func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
if config == nil { if config == nil {
@ -81,7 +37,6 @@ func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
return &ModuleLoader{ return &ModuleLoader{
config: config, config: config,
registry: NewModuleRegistry(),
pathCache: make(map[string]string), pathCache: make(map[string]string),
bytecodeCache: make(map[string][]byte), bytecodeCache: make(map[string][]byte),
debug: false, debug: false,
@ -100,6 +55,13 @@ func (l *ModuleLoader) SetScriptDir(dir string) {
l.config.ScriptDir = dir l.config.ScriptDir = dir
} }
// debugLog logs a message if debug mode is enabled
func (l *ModuleLoader) debugLog(format string, args ...interface{}) {
if l.debug {
logger.Debug("ModuleLoader "+format, args...)
}
}
// SetupRequire configures the require system in a Lua state // SetupRequire configures the require system in a Lua state
func (l *ModuleLoader) SetupRequire(state *luajit.State) error { func (l *ModuleLoader) SetupRequire(state *luajit.State) error {
l.mu.RLock() l.mu.RLock()
@ -207,6 +169,8 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
continue continue
} }
l.debugLog("Scanning directory: %s", absDir)
// Find all Lua files // Find all Lua files
err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error { err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") { if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") {
@ -223,19 +187,22 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
modName := strings.TrimSuffix(relPath, ".lua") modName := strings.TrimSuffix(relPath, ".lua")
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".") modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
l.debugLog("Found module: %s at %s", modName, path)
// Register in our caches // Register in our caches
l.pathCache[modName] = path l.pathCache[modName] = path
l.registry.Register(path, modName)
// Load file content // Load file content
content, err := os.ReadFile(path) content, err := os.ReadFile(path)
if err != nil { if err != nil {
l.debugLog("Failed to read module file: %v", err)
return nil return nil
} }
// Compile to bytecode // Compile to bytecode
bytecode, err := state.CompileBytecode(string(content), path) bytecode, err := state.CompileBytecode(string(content), path)
if err != nil { if err != nil {
l.debugLog("Failed to compile module: %v", err)
return nil return nil
} }
@ -354,10 +321,11 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
// Clean path for proper comparison // Clean path for proper comparison
path = filepath.Clean(path) path = filepath.Clean(path)
// Try direct lookup from registry // Try direct lookup from cache
modName, found := l.registry.GetModuleName(path) for modName, modPath := range l.pathCache {
if found { if modPath == path {
return modName, true return modName, true
}
} }
// Try to find by relative path from lib dirs // Try to find by relative path from lib dirs
@ -373,7 +341,7 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
} }
if strings.HasSuffix(relPath, ".lua") { if strings.HasSuffix(relPath, ".lua") {
modName = strings.TrimSuffix(relPath, ".lua") modName := strings.TrimSuffix(relPath, ".lua")
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".") modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
return modName, true return modName, true
} }
@ -382,103 +350,6 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
return "", false return "", false
} }
// ReloadModule reloads a module from disk
func (l *ModuleLoader) ReloadModule(state *luajit.State, name string) (bool, error) {
l.mu.Lock()
defer l.mu.Unlock()
// Get module path
path, ok := l.registry.GetModulePath(name)
if !ok {
for modName, modPath := range l.pathCache {
if modName == name {
path = modPath
ok = true
break
}
}
}
if !ok || path == "" {
return false, nil
}
// Invalidate module in Lua
err := state.DoString(`
package.loaded["` + name + `"] = nil
__ready_modules["` + name + `"] = nil
if package.preload then
package.preload["` + name + `"] = nil
end
`)
if err != nil {
return false, err
}
// Check if file still exists
if _, err := os.Stat(path); os.IsNotExist(err) {
// File was deleted, just invalidate
delete(l.pathCache, name)
delete(l.bytecodeCache, name)
l.registry.moduleToPath.Delete(name)
l.registry.pathToModule.Delete(path)
return true, nil
}
// Read updated file
content, err := os.ReadFile(path)
if err != nil {
return false, err
}
// Compile to bytecode
bytecode, err := state.CompileBytecode(string(content), path)
if err != nil {
return false, err
}
// Update cache
l.bytecodeCache[name] = bytecode
// Load bytecode into state
if err := state.LoadBytecode(bytecode, path); err != nil {
return false, err
}
// Update preload
luaCode := `
local modname = "` + name + `"
package.loaded[modname] = nil
package.preload[modname] = ...
__ready_modules[modname] = true
`
if err := state.DoString(luaCode); err != nil {
state.Pop(1) // Remove chunk from stack
return false, err
}
state.Pop(1) // Remove chunk from stack
return true, nil
}
// ResetModules clears non-core modules from package.loaded
func (l *ModuleLoader) ResetModules(state *luajit.State) error {
return state.DoString(`
local core_modules = {
string = true, table = true, math = true, os = true,
package = true, io = true, coroutine = true, debug = true, _G = true
}
for name in pairs(package.loaded) do
if not core_modules[name] then
package.loaded[name] = nil
end
end
`)
}
// escapeLuaString escapes special characters in a string for Lua // escapeLuaString escapes special characters in a string for Lua
func escapeLuaString(s string) string { func escapeLuaString(s string) string {
replacer := strings.NewReplacer( replacer := strings.NewReplacer(

76
core/runner/Response.go Normal file
View File

@ -0,0 +1,76 @@
package runner
import (
"sync"
"github.com/valyala/fasthttp"
)
// Response represents a unified response from script execution
type Response struct {
// Basic properties
Body any // Body content (any type)
Metadata map[string]any // Additional metadata
// HTTP specific properties
Status int // HTTP status code
Headers map[string]string // HTTP headers
Cookies []*fasthttp.Cookie // HTTP cookies
// Session information
SessionID string // Session ID
SessionData map[string]any // Session data
SessionModified bool // Whether session was modified
}
// Response pool to reduce allocations
var responsePool = sync.Pool{
New: func() any {
return &Response{
Status: 200,
Headers: make(map[string]string, 8),
Metadata: make(map[string]any, 8),
Cookies: make([]*fasthttp.Cookie, 0, 4),
SessionData: make(map[string]any, 8),
}
},
}
// NewResponse creates a new response object from the pool
func NewResponse() *Response {
return responsePool.Get().(*Response)
}
// Release returns a response to the pool after cleaning it
func ReleaseResponse(resp *Response) {
if resp == nil {
return
}
// Reset fields to default values
resp.Body = nil
resp.Status = 200
// Clear maps
for k := range resp.Headers {
delete(resp.Headers, k)
}
for k := range resp.Metadata {
delete(resp.Metadata, k)
}
for k := range resp.SessionData {
delete(resp.SessionData, k)
}
// Clear cookies
resp.Cookies = resp.Cookies[:0]
// Reset session info
resp.SessionID = ""
resp.SessionModified = false
// Return to pool
responsePool.Put(resp)
}

View File

@ -9,8 +9,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
luaCtx "Moonshark/core/runner/context"
"Moonshark/core/runner/sandbox"
"Moonshark/core/utils/logger" "Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
@ -29,30 +27,22 @@ type RunnerOption func(*Runner)
// State wraps a Lua state with its sandbox // State wraps a Lua state with its sandbox
type State struct { type State struct {
L *luajit.State // The Lua state L *luajit.State // The Lua state
sandbox *sandbox.Sandbox // Associated sandbox sandbox *Sandbox // Associated sandbox
index int // Index for debugging index int // Index for debugging
inUse bool // Whether the state is currently in use inUse bool // Whether the state is currently in use
} }
// InitHook runs before executing a script
type InitHook func(*luajit.State, *luaCtx.Context) error
// FinalizeHook runs after executing a script
type FinalizeHook func(*luajit.State, *luaCtx.Context, any) error
// Runner runs Lua scripts using a pool of Lua states // Runner runs Lua scripts using a pool of Lua states
type Runner struct { type Runner struct {
states []*State // All states managed by this runner states []*State // All states managed by this runner
statePool chan int // Pool of available state indexes statePool chan int // Pool of available state indexes
poolSize int // Size of the state pool poolSize int // Size of the state pool
moduleLoader *ModuleLoader // Module loader moduleLoader *ModuleLoader // Module loader
isRunning atomic.Bool // Whether the runner is active isRunning atomic.Bool // Whether the runner is active
mu sync.RWMutex // Mutex for thread safety mu sync.RWMutex // Mutex for thread safety
debug bool // Enable debug logging debug bool // Enable debug logging
initHooks []InitHook // Hooks run before script execution scriptDir string // Current script directory
finalizeHooks []FinalizeHook // Hooks run after script execution
scriptDir string // Current script directory
} }
// WithPoolSize sets the state pool size // WithPoolSize sets the state pool size
@ -84,28 +74,12 @@ func WithLibDirs(dirs ...string) RunnerOption {
} }
} }
// WithInitHook adds a hook to run before script execution
func WithInitHook(hook InitHook) RunnerOption {
return func(r *Runner) {
r.initHooks = append(r.initHooks, hook)
}
}
// WithFinalizeHook adds a hook to run after script execution
func WithFinalizeHook(hook FinalizeHook) RunnerOption {
return func(r *Runner) {
r.finalizeHooks = append(r.finalizeHooks, hook)
}
}
// NewRunner creates a new Runner with a pool of states // NewRunner creates a new Runner with a pool of states
func NewRunner(options ...RunnerOption) (*Runner, error) { func NewRunner(options ...RunnerOption) (*Runner, error) {
// Default configuration // Default configuration
runner := &Runner{ runner := &Runner{
poolSize: runtime.GOMAXPROCS(0), poolSize: runtime.GOMAXPROCS(0),
debug: false, debug: false,
initHooks: make([]InitHook, 0, 4),
finalizeHooks: make([]FinalizeHook, 0, 4),
} }
// Apply options // Apply options
@ -122,6 +96,11 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
runner.moduleLoader = NewModuleLoader(config) runner.moduleLoader = NewModuleLoader(config)
} }
// Enable debug if requested
if runner.debug {
runner.moduleLoader.EnableDebug()
}
// Initialize states and pool // Initialize states and pool
runner.states = make([]*State, runner.poolSize) runner.states = make([]*State, runner.poolSize)
runner.statePool = make(chan int, runner.poolSize) runner.statePool = make(chan int, runner.poolSize)
@ -145,7 +124,7 @@ func (r *Runner) debugLog(format string, args ...interface{}) {
// initializeStates creates and initializes all states in the pool // initializeStates creates and initializes all states in the pool
func (r *Runner) initializeStates() error { func (r *Runner) initializeStates() error {
r.debugLog("is initializing %d states", r.poolSize) r.debugLog("Initializing %d states", r.poolSize)
// Create all states // Create all states
for i := 0; i < r.poolSize; i++ { for i := 0; i < r.poolSize; i++ {
@ -175,39 +154,36 @@ func (r *Runner) createState(index int) (*State, error) {
} }
// Create sandbox // Create sandbox
sb := sandbox.NewSandbox() sb := NewSandbox()
if r.debug && verbose { if r.debug {
sb.EnableDebug() sb.EnableDebug()
} }
// Set up require system // Set up sandbox
if err := sb.Setup(L); err != nil {
L.Cleanup()
L.Close()
return nil, ErrInitFailed
}
// Set up module loader
if err := r.moduleLoader.SetupRequire(L); err != nil { if err := r.moduleLoader.SetupRequire(L); err != nil {
L.Cleanup() L.Cleanup()
L.Close() L.Close()
return nil, ErrInitFailed return nil, ErrInitFailed
} }
// Initialize all core modules from the registry // Preload modules
if err := GlobalRegistry.Initialize(L, index); err != nil {
L.Cleanup()
L.Close()
return nil, ErrInitFailed
}
// Set up sandbox after core modules are initialized
if err := sb.Setup(L, index); err != nil {
L.Cleanup()
L.Close()
return nil, ErrInitFailed
}
// Preload all modules
if err := r.moduleLoader.PreloadModules(L); err != nil { if err := r.moduleLoader.PreloadModules(L); err != nil {
L.Cleanup() L.Cleanup()
L.Close() L.Close()
return nil, errors.New("failed to preload modules") return nil, errors.New("failed to preload modules")
} }
if verbose {
r.debugLog("Lua state %d initialized successfully", index)
}
return &State{ return &State{
L: L, L: L,
sandbox: sb, sandbox: sb,
@ -216,8 +192,8 @@ func (r *Runner) createState(index int) (*State, error) {
}, nil }, nil
} }
// Execute runs a script with context // Execute runs a script in a sandbox with context
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *luaCtx.Context, scriptPath string) (any, error) { func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
if !r.isRunning.Load() { if !r.isRunning.Load() {
return nil, ErrRunnerClosed return nil, ErrRunnerClosed
} }
@ -264,70 +240,17 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *luaCtx.C
} }
}() }()
// Run init hooks // Execute in sandbox
for _, hook := range r.initHooks { response, err := state.sandbox.Execute(state.L, bytecode, execCtx)
if err := hook(state.L, execCtx); err != nil {
return nil, err
}
}
// Get context values
var ctxValues map[string]any
if execCtx != nil {
ctxValues = execCtx.Values
}
// Execute in sandbox with optimized context handling
var result any
var err error
if execCtx != nil && execCtx.RequestCtx != nil {
// Use OptimizedExecute directly with the full context if we have RequestCtx
result, err = state.sandbox.OptimizedExecute(state.L, bytecode, &luaCtx.Context{
Values: ctxValues,
RequestCtx: execCtx.RequestCtx,
})
} else {
// Otherwise use standard Execute with just values
result, err = state.sandbox.Execute(state.L, bytecode, ctxValues)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Run finalize hooks return response, nil
for _, hook := range r.finalizeHooks {
if hookErr := hook(state.L, execCtx, result); hookErr != nil {
return nil, hookErr
}
}
// Check for HTTP response if we don't have a RequestCtx or if we still have a result
if execCtx == nil || execCtx.RequestCtx == nil || result != nil {
httpResp, hasResponse := sandbox.GetHTTPResponse(state.L)
if hasResponse {
// Set result as body if not already set
if httpResp.Body == nil {
httpResp.Body = result
}
// Apply directly to request context if available
if execCtx != nil && execCtx.RequestCtx != nil {
sandbox.ApplyHTTPResponse(httpResp, execCtx.RequestCtx)
sandbox.ReleaseResponse(httpResp)
return nil, nil
}
return httpResp, nil
}
}
return result, err
} }
// Run executes a Lua script (convenience wrapper) // Run executes a Lua script with immediate context
func (r *Runner) Run(bytecode []byte, execCtx *luaCtx.Context, scriptPath string) (any, error) { func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
return r.Execute(context.Background(), bytecode, execCtx, scriptPath) return r.Execute(context.Background(), bytecode, execCtx, scriptPath)
} }
@ -363,6 +286,7 @@ cleanup:
} }
} }
r.debugLog("Runner closed")
return nil return nil
} }
@ -375,6 +299,8 @@ func (r *Runner) RefreshStates() error {
return ErrRunnerClosed return ErrRunnerClosed
} }
r.debugLog("Refreshing all states...")
// Drain all states from the pool // Drain all states from the pool
for { for {
select { select {
@ -408,81 +334,6 @@ cleanup:
return nil return nil
} }
// AddInitHook adds a hook to be called before script execution
func (r *Runner) AddInitHook(hook InitHook) {
r.mu.Lock()
defer r.mu.Unlock()
r.initHooks = append(r.initHooks, hook)
}
// AddFinalizeHook adds a hook to be called after script execution
func (r *Runner) AddFinalizeHook(hook FinalizeHook) {
r.mu.Lock()
defer r.mu.Unlock()
r.finalizeHooks = append(r.finalizeHooks, hook)
}
// GetStateCount returns the number of initialized states
func (r *Runner) GetStateCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
count := 0
for _, state := range r.states {
if state != nil {
count++
}
}
return count
}
// GetActiveStateCount returns the number of states currently in use
func (r *Runner) GetActiveStateCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
count := 0
for _, state := range r.states {
if state != nil && state.inUse {
count++
}
}
return count
}
// GetModuleCount returns the number of loaded modules in the first available state
func (r *Runner) GetModuleCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
if !r.isRunning.Load() {
return 0
}
// Find first available state
for _, state := range r.states {
if state != nil && !state.inUse {
// Execute a Lua snippet to count modules
if res, err := state.L.ExecuteWithResult(`
local count = 0
for _ in pairs(package.loaded) do
count = count + 1
end
return count
`); err == nil {
if num, ok := res.(float64); ok {
return int(num)
}
}
break
}
}
return 0
}
// NotifyFileChanged alerts the runner about file changes // NotifyFileChanged alerts the runner about file changes
func (r *Runner) NotifyFileChanged(filePath string) bool { func (r *Runner) NotifyFileChanged(filePath string) bool {
r.debugLog("File change detected: %s", filePath) r.debugLog("File change detected: %s", filePath)
@ -514,9 +365,6 @@ func (r *Runner) RefreshModule(moduleName string) bool {
r.debugLog("Refreshing module: %s", moduleName) r.debugLog("Refreshing module: %s", moduleName)
// Check if it's a core module
coreName, isCore := GlobalRegistry.MatchModuleName(moduleName)
success := true success := true
for _, state := range r.states { for _, state := range r.states {
if state == nil || state.inUse { if state == nil || state.inUse {
@ -526,16 +374,39 @@ func (r *Runner) RefreshModule(moduleName string) bool {
// Invalidate module in Lua // Invalidate module in Lua
if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil { if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil {
success = false success = false
continue r.debugLog("Failed to invalidate module %s: %v", moduleName, err)
}
// For core modules, reinitialize them
if isCore {
if err := GlobalRegistry.InitializeModule(state.L, coreName); err != nil {
success = false
}
} }
} }
return success return success
} }
// GetStateCount returns the number of initialized states
func (r *Runner) GetStateCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
count := 0
for _, state := range r.states {
if state != nil {
count++
}
}
return count
}
// GetActiveStateCount returns the number of states currently in use
func (r *Runner) GetActiveStateCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
count := 0
for _, state := range r.states {
if state != nil && state.inUse {
count++
}
}
return count
}

345
core/runner/Sandbox.go Normal file
View File

@ -0,0 +1,345 @@
package runner
import (
"fmt"
"sync"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Error represents a simple error string
type Error string
func (e Error) Error() string {
return string(e)
}
// Error types
var (
ErrSandboxNotInitialized = Error("sandbox not initialized")
)
// Sandbox provides a secure execution environment for Lua scripts
type Sandbox struct {
modules map[string]any
debug bool
mu sync.RWMutex
}
// NewSandbox creates a new sandbox environment
func NewSandbox() *Sandbox {
return &Sandbox{
modules: make(map[string]any, 8),
debug: false,
}
}
// EnableDebug turns on debug logging
func (s *Sandbox) EnableDebug() {
s.debug = true
}
// debugLog logs a message if debug mode is enabled
func (s *Sandbox) debugLog(format string, args ...interface{}) {
if s.debug {
logger.Debug("Sandbox "+format, args...)
}
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) {
s.mu.Lock()
defer s.mu.Unlock()
s.modules[name] = module
s.debugLog("Added module: %s", name)
}
// Setup initializes the sandbox in a Lua state
func (s *Sandbox) Setup(state *luajit.State) error {
s.debugLog("Setting up sandbox...")
// Load the sandbox code
if err := loadSandboxIntoState(state); err != nil {
s.debugLog("Failed to load sandbox: %v", err)
return err
}
// Register core functions
if err := s.registerCoreFunctions(state); err != nil {
s.debugLog("Failed to register core functions: %v", err)
return err
}
// Register custom modules in the global environment
s.mu.RLock()
for name, module := range s.modules {
s.debugLog("Registering module: %s", name)
if err := state.PushValue(module); err != nil {
s.mu.RUnlock()
s.debugLog("Failed to register module %s: %v", name, err)
return err
}
state.SetGlobal(name)
}
s.mu.RUnlock()
s.debugLog("Sandbox setup complete")
return nil
}
// registerCoreFunctions registers all built-in functions in the Lua state
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Register HTTP functions
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
return err
}
// Register utility functions
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
return err
}
// Additional registrations can be added here
return nil
}
// Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
s.debugLog("Executing script...")
// Create a response object
response := NewResponse()
// Get a buffer for string operations
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
// Load bytecode
if err := state.LoadBytecode(bytecode, "script"); err != nil {
ReleaseResponse(response)
s.debugLog("Failed to load bytecode: %v", err)
return nil, fmt.Errorf("failed to load script: %w", err)
}
// Initialize session data in Lua
if ctx.SessionID != "" {
// Set session ID
state.PushString(ctx.SessionID)
state.SetGlobal("__session_id")
// Set session data
if err := state.PushTable(ctx.SessionData); err != nil {
ReleaseResponse(response)
s.debugLog("Failed to push session data: %v", err)
return nil, err
}
state.SetGlobal("__session_data")
// Reset modification flag
state.PushBoolean(false)
state.SetGlobal("__session_modified")
} else {
// Initialize empty session
if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil {
s.debugLog("Failed to initialize empty session data: %v", err)
}
}
// Set up context values for execution
if err := state.PushTable(ctx.Values); err != nil {
ReleaseResponse(response)
s.debugLog("Failed to push context values: %v", err)
return nil, err
}
// Get the execution function
state.GetGlobal("__execute_script")
if !state.IsFunction(-1) {
state.Pop(1) // Pop non-function
ReleaseResponse(response)
s.debugLog("__execute_script is not a function")
return nil, ErrSandboxNotInitialized
}
// Push function and context to stack
state.PushCopy(-2) // bytecode
state.PushCopy(-2) // context
// Remove duplicates
state.Remove(-4)
state.Remove(-3)
// Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil {
ReleaseResponse(response)
s.debugLog("Execution failed: %v", err)
return nil, fmt.Errorf("script execution failed: %w", err)
}
// Set response body from result
body, err := state.ToValue(-1)
if err == nil {
response.Body = body
}
state.Pop(1)
// Extract HTTP response data from Lua state
s.extractResponseData(state, response)
return response, nil
}
// extractResponseData pulls response info from the Lua state
func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) {
// Get HTTP response
state.GetGlobal("__http_responses")
if !state.IsNil(-1) && state.IsTable(-1) {
state.PushNumber(1)
state.GetTable(-2)
if !state.IsNil(-1) && state.IsTable(-1) {
// Extract status
state.GetField(-1, "status")
if state.IsNumber(-1) {
response.Status = int(state.ToNumber(-1))
}
state.Pop(1)
// Extract headers
state.GetField(-1, "headers")
if state.IsTable(-1) {
state.PushNil() // Start iteration
for state.Next(-2) {
if state.IsString(-2) && state.IsString(-1) {
key := state.ToString(-2)
value := state.ToString(-1)
response.Headers[key] = value
}
state.Pop(1)
}
}
state.Pop(1)
// Extract cookies
state.GetField(-1, "cookies")
if state.IsTable(-1) {
length := state.GetTableLength(-1)
for i := 1; i <= length; i++ {
state.PushNumber(float64(i))
state.GetTable(-2)
if state.IsTable(-1) {
s.extractCookie(state, response)
}
state.Pop(1)
}
}
state.Pop(1)
// Extract metadata if present
state.GetField(-1, "metadata")
if state.IsTable(-1) {
table, err := state.ToTable(-1)
if err == nil {
for k, v := range table {
response.Metadata[k] = v
}
}
}
state.Pop(1)
}
state.Pop(1)
}
state.Pop(1)
// Extract session data
state.GetGlobal("__session_modified")
if state.IsBoolean(-1) && state.ToBoolean(-1) {
response.SessionModified = true
// Get session ID
state.GetGlobal("__session_id")
if state.IsString(-1) {
response.SessionID = state.ToString(-1)
}
state.Pop(1)
// Get session data
state.GetGlobal("__session_data")
if state.IsTable(-1) {
sessionData, err := state.ToTable(-1)
if err == nil {
for k, v := range sessionData {
response.SessionData[k] = v
}
}
}
state.Pop(1)
}
state.Pop(1)
}
// extractCookie pulls cookie data from the current table on the stack
func (s *Sandbox) extractCookie(state *luajit.State, response *Response) {
cookie := fasthttp.AcquireCookie()
// Get name (required)
state.GetField(-1, "name")
if !state.IsString(-1) {
state.Pop(1)
fasthttp.ReleaseCookie(cookie)
return
}
cookie.SetKey(state.ToString(-1))
state.Pop(1)
// Get value
state.GetField(-1, "value")
if state.IsString(-1) {
cookie.SetValue(state.ToString(-1))
}
state.Pop(1)
// Get path
state.GetField(-1, "path")
if state.IsString(-1) {
cookie.SetPath(state.ToString(-1))
} else {
cookie.SetPath("/") // Default
}
state.Pop(1)
// Get domain
state.GetField(-1, "domain")
if state.IsString(-1) {
cookie.SetDomain(state.ToString(-1))
}
state.Pop(1)
// Get other parameters
state.GetField(-1, "http_only")
if state.IsBoolean(-1) {
cookie.SetHTTPOnly(state.ToBoolean(-1))
}
state.Pop(1)
state.GetField(-1, "secure")
if state.IsBoolean(-1) {
cookie.SetSecure(state.ToBoolean(-1))
}
state.Pop(1)
state.GetField(-1, "max_age")
if state.IsNumber(-1) {
cookie.SetMaxAge(int(state.ToNumber(-1)))
}
state.Pop(1)
response.Cookies = append(response.Cookies, cookie)
}

View File

@ -1,241 +0,0 @@
package runner
import (
luaCtx "Moonshark/core/runner/context"
"Moonshark/core/runner/sandbox"
"Moonshark/core/sessions"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/valyala/fasthttp"
)
// SessionHandler handles session management for Lua scripts
type SessionHandler struct {
manager *sessions.SessionManager
debugLog bool
}
// NewSessionHandler creates a new session handler
func NewSessionHandler(manager *sessions.SessionManager) *SessionHandler {
return &SessionHandler{
manager: manager,
debugLog: false,
}
}
// EnableDebug enables debug logging
func (h *SessionHandler) EnableDebug() {
h.debugLog = true
}
// WithSessionManager creates a RunnerOption to add session support
func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
return func(r *Runner) {
handler := NewSessionHandler(manager)
r.AddInitHook(handler.preRequestHook)
r.AddFinalizeHook(handler.postRequestHook)
}
}
// preRequestHook initializes session before script execution
func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *luaCtx.Context) error {
if ctx == nil || ctx.Values["_request_cookies"] == nil {
return nil
}
// Extract cookies from context
cookies, ok := ctx.Values["_request_cookies"].(map[string]any)
if !ok {
return nil
}
// Get the session ID from cookies
cookieName := h.manager.CookieOptions()["name"].(string)
var sessionID string
// Check if our session cookie exists
if cookieValue, exists := cookies[cookieName]; exists {
if strValue, ok := cookieValue.(string); ok && strValue != "" {
sessionID = strValue
}
}
// Create new session if needed
if sessionID == "" {
session := h.manager.CreateSession()
sessionID = session.ID
}
// Store the session ID in the context
ctx.Set("_session_id", sessionID)
// Get session data
session := h.manager.GetSession(sessionID)
sessionData := session.GetAll()
// Set session data in Lua state
return SetSessionData(state, sessionID, sessionData)
}
// postRequestHook handles session after script execution
func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *luaCtx.Context, result any) error {
// Check if session was modified
modifiedID, modifiedData, modified := GetSessionData(state)
if !modified {
return nil
}
// Get the original session ID from context
var sessionID string
if ctx != nil {
if id, ok := ctx.Values["_session_id"].(string); ok {
sessionID = id
}
}
// Use the original session ID if the modified one is empty
if modifiedID == "" {
modifiedID = sessionID
}
if modifiedID == "" {
return nil
}
// Update session in manager
session := h.manager.GetSession(modifiedID)
session.Clear() // clear to sync deleted values
for k, v := range modifiedData {
session.Set(k, v)
}
h.manager.SaveSession(session)
// Add session cookie to result if it's an HTTP response
if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
h.addSessionCookie(httpResp, modifiedID)
} else if ctx != nil && ctx.RequestCtx != nil {
// Add cookie directly to the RequestCtx when result is not an HTTP response
h.addSessionCookieToRequestCtx(ctx.RequestCtx, modifiedID)
}
return nil
}
// addSessionCookie adds a session cookie to an HTTP response
func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) {
// Get cookie options
opts := h.manager.CookieOptions()
// Check if session cookie is already set
cookieName := opts["name"].(string)
for _, cookie := range resp.Cookies {
if string(cookie.Key()) == cookieName {
return
}
}
// Create and add cookie
cookie := fasthttp.AcquireCookie()
cookie.SetKey(cookieName)
cookie.SetValue(sessionID)
cookie.SetPath(opts["path"].(string))
cookie.SetHTTPOnly(opts["http_only"].(bool))
cookie.SetMaxAge(opts["max_age"].(int))
// Optional cookie parameters
if domain, ok := opts["domain"].(string); ok && domain != "" {
cookie.SetDomain(domain)
}
if secure, ok := opts["secure"].(bool); ok {
cookie.SetSecure(secure)
}
resp.Cookies = append(resp.Cookies, cookie)
}
func (h *SessionHandler) addSessionCookieToRequestCtx(ctx *fasthttp.RequestCtx, sessionID string) {
// Get cookie options
opts := h.manager.CookieOptions()
cookieName := opts["name"].(string)
// Create cookie
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
cookie.SetKey(cookieName)
cookie.SetValue(sessionID)
cookie.SetPath(opts["path"].(string))
cookie.SetHTTPOnly(opts["http_only"].(bool))
cookie.SetMaxAge(opts["max_age"].(int))
// Optional cookie parameters
if domain, ok := opts["domain"].(string); ok && domain != "" {
cookie.SetDomain(domain)
}
if secure, ok := opts["secure"].(bool); ok {
cookie.SetSecure(secure)
}
ctx.Response.Header.SetCookie(cookie)
}
// GetSessionData extracts session data from Lua state
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
// Check if session was modified
state.GetGlobal("__session_modified")
modified := state.ToBoolean(-1)
state.Pop(1)
if !modified {
return "", nil, false
}
// Get session ID
state.GetGlobal("__session_id")
sessionID := state.ToString(-1)
state.Pop(1)
// Get session data
state.GetGlobal("__session_data")
if !state.IsTable(-1) {
state.Pop(1)
return sessionID, nil, false
}
data, err := state.ToTable(-1)
state.Pop(1)
if err != nil {
logger.Error("Failed to extract session data: %v", err)
return sessionID, nil, false
}
return sessionID, data, true
}
// SetSessionData sets session data in Lua state
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
// Set session ID
state.PushString(sessionID)
state.SetGlobal("__session_id")
// Set session data
if data == nil {
data = make(map[string]any)
}
if err := state.PushTable(data); err != nil {
return err
}
state.SetGlobal("__session_data")
// Reset modification flag
state.PushBoolean(false)
state.SetGlobal("__session_modified")
return nil
}

View File

@ -14,9 +14,6 @@ __ready_modules = {}
__session_data = {} __session_data = {}
__session_id = nil __session_id = nil
__session_modified = false __session_modified = false
__env_system = {
base_env = {}
}
-- ====================================================================== -- ======================================================================
-- CORE SANDBOX FUNCTIONALITY -- CORE SANDBOX FUNCTIONALITY
@ -44,7 +41,7 @@ end
function __execute_script(fn, ctx) function __execute_script(fn, ctx)
-- Clear previous responses -- Clear previous responses
__http_responses[1] = nil __http_responses[1] = nil
-- Reset session modification flag -- Reset session modification flag
__session_modified = false __session_modified = false
@ -63,75 +60,6 @@ function __execute_script(fn, ctx)
return result return result
end end
-- ======================================================================
-- MODULE LOADING SYSTEM
-- ======================================================================
-- Setup environment-aware require function
function __setup_require(env)
-- Create require function specific to this environment
env.require = function(modname)
-- Check if already loaded
if package.loaded[modname] then
return package.loaded[modname]
end
-- Check preloaded modules
if __ready_modules[modname] then
local loader = package.preload[modname]
if loader then
-- Set environment for loader
setfenv(loader, env)
-- Execute and store result
local result = loader()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
end
-- Direct file load as fallback
if __module_paths[modname] then
local path = __module_paths[modname]
local chunk, err = loadfile(path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
end
-- Full path search as last resort
local errors = {}
for path in package.path:gmatch("[^;]+") do
local file_path = path:gsub("?", modname:gsub("%.", "/"))
local chunk, err = loadfile(file_path)
if chunk then
setfenv(chunk, env)
local result = chunk()
if result == nil then
result = true
end
package.loaded[modname] = result
return result
end
table.insert(errors, "\tno file '" .. file_path .. "'")
end
error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2)
end
return env
end
-- ====================================================================== -- ======================================================================
-- HTTP MODULE -- HTTP MODULE
-- ====================================================================== -- ======================================================================
@ -166,6 +94,18 @@ local http = {
http.set_header("Content-Type", content_type) http.set_header("Content-Type", content_type)
end, end,
-- Set metadata (arbitrary data to be returned with response)
set_metadata = function(key, value)
if type(key) ~= "string" then
error("http.set_metadata: key must be a string", 2)
end
local resp = __http_responses[1] or {}
resp.metadata = resp.metadata or {}
resp.metadata[key] = value
__http_responses[1] = resp
end,
-- HTTP client submodule -- HTTP client submodule
client = { client = {
-- Generic request function -- Generic request function
@ -213,10 +153,7 @@ local http = {
-- Simple HEAD request -- Simple HEAD request
head = function(url, options) head = function(url, options)
options = options or {} options = options or {}
local old_options = options return http.client.request("HEAD", url, nil, options)
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
local response = http.client.request("HEAD", url, nil, options)
return response
end, end,
-- Simple OPTIONS request -- Simple OPTIONS request
@ -265,13 +202,13 @@ local http = {
} }
-- ====================================================================== -- ======================================================================
-- COOKIE MODULE -- COOKIE MODULE
-- ====================================================================== -- ======================================================================
-- Cookie module implementation -- Cookie module implementation
local cookie = { local cookie = {
-- Set a cookie -- Set a cookie
set = function(name, value, options, ...) set = function(name, value, options)
if type(name) ~= "string" then if type(name) ~= "string" then
error("cookie.set: name must be a string", 2) error("cookie.set: name must be a string", 2)
end end
@ -281,20 +218,8 @@ local cookie = {
resp.cookies = resp.cookies or {} resp.cookies = resp.cookies or {}
__http_responses[1] = resp __http_responses[1] = resp
-- Handle options as table or legacy params -- Handle options as table
local opts = {} local opts = options or {}
if type(options) == "table" then
opts = options
elseif options ~= nil then
-- Legacy support: options is actually 'expires'
opts.expires = options
-- Check for other legacy params (4th-7th args)
local args = {...}
if args[1] then opts.path = args[1] end
if args[2] then opts.domain = args[2] end
if args[3] then opts.secure = args[3] end
if args[4] ~= nil then opts.http_only = args[4] end
end
-- Create cookie table -- Create cookie table
local cookie = { local cookie = {
@ -314,10 +239,8 @@ local cookie = {
elseif opts.expires < 0 then elseif opts.expires < 0 then
cookie.expires = 1 cookie.expires = 1
cookie.max_age = 0 cookie.max_age = 0
else
-- opts.expires == 0: Session cookie
-- Do nothing (omitting both expires and max-age creates a session cookie)
end end
-- opts.expires == 0: Session cookie (omitting both expires and max-age)
end end
end end
@ -342,8 +265,13 @@ local cookie = {
local env = getfenv(2) local env = getfenv(2)
-- Check if context exists and has cookies -- Check if context exists and has cookies
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then if env.ctx and env.ctx.cookies then
return tostring(env.ctx.cookies[name]) return env.ctx.cookies[name]
end
-- If context has request_cookies map
if env.ctx and env.ctx._request_cookies then
return env.ctx._request_cookies[name]
end end
return nil return nil
@ -361,7 +289,7 @@ local cookie = {
} }
-- ====================================================================== -- ======================================================================
-- SESSION MODULE -- SESSION MODULE
-- ====================================================================== -- ======================================================================
-- Session module implementation -- Session module implementation
@ -372,7 +300,7 @@ local session = {
error("session.get: key must be a string", 2) error("session.get: key must be a string", 2)
end end
if __session_data and __session_data[key] then if __session_data and __session_data[key] ~= nil then
return __session_data[key] return __session_data[key]
end end
@ -469,7 +397,7 @@ local csrf = {
error("CSRF protection requires the session module", 2) error("CSRF protection requires the session module", 2)
end end
local token = util.generate_token(length) local token = __generate_token(length)
session.set(csrf.TOKEN_KEY, token) session.set(csrf.TOKEN_KEY, token)
return token return token
end, end,
@ -495,48 +423,133 @@ local csrf = {
end, end,
-- Verify a given token against the session token -- Verify a given token against the session token
verify = function(token, field_name) verify = function(token, field_name)
field_name = field_name or csrf.DEFAULT_FIELD field_name = field_name or csrf.DEFAULT_FIELD
local env = getfenv(2) local env = getfenv(2)
local form = nil local form = nil
if env.ctx and env.ctx.form then if env.ctx and env.ctx._request_form then
form = env.ctx.form form = env.ctx._request_form
else elseif env.ctx and env.ctx.form then
return false form = env.ctx.form
end else
return false
end
token = token or form[field_name] token = token or form[field_name]
if not token then if not token then
return false return false
end end
local session_token = session.get(csrf.TOKEN_KEY) local session_token = session.get(csrf.TOKEN_KEY)
if not session_token then if not session_token then
return false return false
end end
-- Constant-time comparison to prevent timing attacks -- Constant-time comparison to prevent timing attacks
-- This is safe since Lua strings are immutable if #token ~= #session_token then
if #token ~= #session_token then return false
return false end
end
local result = true local result = true
for i = 1, #token do for i = 1, #token do
if token:sub(i, i) ~= session_token:sub(i, i) then if token:sub(i, i) ~= session_token:sub(i, i) then
result = false result = false
-- Don't break early - continue to prevent timing attacks -- Don't break early - continue to prevent timing attacks
end end
end end
return result return result
end end
} }
-- ====================================================================== -- ======================================================================
-- REGISTER MODULES GLOBALLY -- UTIL MODULE
-- ======================================================================
-- Utility module implementation
local util = {
-- Generate a token (wrapper around __generate_token)
generate_token = function(length)
return __generate_token(length or 32)
end,
-- Simple JSON stringify (for when you just need a quick string)
json_encode = function(value)
if type(value) == "table" then
local json = "{"
local sep = ""
for k, v in pairs(value) do
json = json .. sep
if type(k) == "number" then
-- Array-like
json = json .. util.json_encode(v)
else
-- Object-like
json = json .. '"' .. k .. '":' .. util.json_encode(v)
end
sep = ","
end
return json .. "}"
elseif type(value) == "string" then
return '"' .. value:gsub('"', '\\"'):gsub('\n', '\\n') .. '"'
elseif type(value) == "number" then
return tostring(value)
elseif type(value) == "boolean" then
return value and "true" or "false"
elseif value == nil then
return "null"
end
return '"' .. tostring(value) .. '"'
end,
-- Deep copy of tables
deep_copy = function(obj)
if type(obj) ~= 'table' then return obj end
local res = {}
for k, v in pairs(obj) do res[k] = util.deep_copy(v) end
return res
end,
-- Merge tables
merge_tables = function(t1, t2)
if type(t1) ~= 'table' or type(t2) ~= 'table' then
error("Both arguments must be tables", 2)
end
local result = util.deep_copy(t1)
for k, v in pairs(t2) do
if type(v) == 'table' and type(result[k]) == 'table' then
result[k] = util.merge_tables(result[k], v)
else
result[k] = v
end
end
return result
end,
-- String utilities
string = {
-- Trim whitespace
trim = function(s)
return (s:gsub("^%s*(.-)%s*$", "%1"))
end,
-- Split string
split = function(s, delimiter)
delimiter = delimiter or ","
local result = {}
for match in (s..delimiter):gmatch("(.-)"..delimiter) do
table.insert(result, match)
end
return result
end
}
}
-- ======================================================================
-- REGISTER MODULES GLOBALLY
-- ====================================================================== -- ======================================================================
-- Install modules in global scope -- Install modules in global scope
@ -544,9 +557,4 @@ _G.http = http
_G.cookie = cookie _G.cookie = cookie
_G.session = session _G.session = session
_G.csrf = csrf _G.csrf = csrf
_G.util = util
-- Register modules in sandbox base environment
__env_system.base_env.http = http
__env_system.base_env.cookie = cookie
__env_system.base_env.session = session
__env_system.base_env.csrf = csrf

View File

@ -1,98 +0,0 @@
package sandbox
import (
_ "embed"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
//go:embed lua/sandbox.lua
var sandboxLua string
// InitializeSandbox loads the embedded Lua sandbox code into a Lua state
func InitializeSandbox(state *luajit.State) error {
// Compile once, use many times
bytecodeOnce.Do(precompileSandbox)
if sandboxBytecode != nil {
logger.Debug("Loading sandbox.lua from precompiled bytecode")
return state.LoadAndRunBytecode(sandboxBytecode, "sandbox.lua")
}
// Fallback if compilation failed
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
return state.DoString(sandboxLua)
}
// ModuleInitializers stores initializer functions for core modules
type ModuleInitializers struct {
HTTP func(*luajit.State) error
Util func(*luajit.State) error
Session func(*luajit.State) error
Cookie func(*luajit.State) error
CSRF func(*luajit.State) error
}
// DefaultInitializers returns the default set of initializers
func DefaultInitializers() *ModuleInitializers {
return &ModuleInitializers{
HTTP: func(state *luajit.State) error {
// Register the native Go function first
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
return err
}
return nil
},
Util: func(state *luajit.State) error {
// Register util functions
return RegisterModule(state, "util", UtilModuleFunctions())
},
Session: func(state *luajit.State) error {
// Session doesn't need special initialization
return nil
},
Cookie: func(state *luajit.State) error {
// Cookie doesn't need special initialization
return nil
},
CSRF: func(state *luajit.State) error {
// CSRF doesn't need special initialization
return nil
},
}
}
// InitializeAll initializes all modules in the Lua state
func InitializeAll(state *luajit.State, initializers *ModuleInitializers) error {
// Set up dependencies first
if err := initializers.Util(state); err != nil {
return err
}
if err := initializers.HTTP(state); err != nil {
return err
}
// Load the embedded sandbox code
if err := InitializeSandbox(state); err != nil {
return err
}
// Initialize the rest of the modules
if err := initializers.Session(state); err != nil {
return err
}
if err := initializers.Cookie(state); err != nil {
return err
}
if err := initializers.CSRF(state); err != nil {
return err
}
return nil
}

View File

@ -1,590 +0,0 @@
package sandbox
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"sync"
"time"
"github.com/goccy/go-json"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// SessionHandler interface for session management
type SessionHandler interface {
LoadSession(ctx *fasthttp.RequestCtx) (string, map[string]any)
SaveSession(ctx *fasthttp.RequestCtx, sessionID string, data map[string]any) bool
}
// HTTPResponse represents an HTTP response from Lua
type HTTPResponse struct {
Status int `json:"status"`
Headers map[string]string `json:"headers"`
Body any `json:"body"`
Cookies []*fasthttp.Cookie `json:"-"`
SessionModified bool `json:"-"`
}
// Response pool to reduce allocations
var responsePool = sync.Pool{
New: func() any {
return &HTTPResponse{
Status: 200,
Headers: make(map[string]string, 8),
Cookies: make([]*fasthttp.Cookie, 0, 4),
}
},
}
// Default HTTP client with sensible timeout
var defaultFastClient fasthttp.Client = fasthttp.Client{
MaxConnsPerHost: 1024,
MaxIdleConnDuration: time.Minute,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
DisableHeaderNamesNormalizing: true,
}
// HTTPClientConfig contains client settings
type HTTPClientConfig struct {
MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit)
DefaultTimeout time.Duration // Default request timeout
MaxResponseSize int64 // Maximum response size in bytes (0 = no limit)
AllowRemote bool // Whether to allow remote connections
}
// DefaultHTTPClientConfig provides sensible defaults
var DefaultHTTPClientConfig = HTTPClientConfig{
MaxTimeout: 60 * time.Second,
DefaultTimeout: 30 * time.Second,
MaxResponseSize: 10 * 1024 * 1024, // 10MB
AllowRemote: true,
}
// NewHTTPResponse creates a default HTTP response from pool
func NewHTTPResponse() *HTTPResponse {
return responsePool.Get().(*HTTPResponse)
}
// ReleaseResponse returns the response to the pool
func ReleaseResponse(resp *HTTPResponse) {
if resp == nil {
return
}
// Clear all values to prevent data leakage
resp.Status = 200 // Reset to default
// Clear headers
for k := range resp.Headers {
delete(resp.Headers, k)
}
// Clear cookies
resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0
// Reset session flag
resp.SessionModified = false
// Clear body
resp.Body = nil
responsePool.Put(resp)
}
// HTTPModuleInitFunc returns an initializer function for the HTTP module
func HTTPModuleInitFunc() func(*luajit.State) error {
return func(state *luajit.State) error {
// Register the native Go function first
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
return err
}
// Set up default HTTP client configuration
setupHTTPClientConfig(state)
return nil
}
}
// setupHTTPClientConfig configures HTTP client in Lua
func setupHTTPClientConfig(state *luajit.State) {
state.NewTable()
state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
state.SetField(-2, "max_timeout")
state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second))
state.SetField(-2, "default_timeout")
state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize))
state.SetField(-2, "max_response_size")
state.PushBoolean(DefaultHTTPClientConfig.AllowRemote)
state.SetField(-2, "allow_remote")
state.SetGlobal("__http_client_config")
}
// GetHTTPResponse extracts the HTTP response from Lua state
func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
response := NewHTTPResponse()
// Get response table
state.GetGlobal("__http_responses")
if state.IsNil(-1) {
state.Pop(1)
ReleaseResponse(response)
return nil, false
}
// Check for response at thread index
state.PushNumber(1)
state.GetTable(-2)
if state.IsNil(-1) {
state.Pop(2)
ReleaseResponse(response)
return nil, false
}
// Get status
state.GetField(-1, "status")
if state.IsNumber(-1) {
response.Status = int(state.ToNumber(-1))
}
state.Pop(1)
// Get headers
state.GetField(-1, "headers")
if state.IsTable(-1) {
// Iterate through headers table
state.PushNil() // Start iteration
for state.Next(-2) {
// Stack has key at -2 and value at -1
if state.IsString(-2) && state.IsString(-1) {
key := state.ToString(-2)
value := state.ToString(-1)
response.Headers[key] = value
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1)
// Get cookies
state.GetField(-1, "cookies")
if state.IsTable(-1) {
// Iterate through cookies array
length := state.GetTableLength(-1)
for i := 1; i <= length; i++ {
state.PushNumber(float64(i))
state.GetTable(-2)
if state.IsTable(-1) {
cookie := extractCookie(state)
if cookie != nil {
response.Cookies = append(response.Cookies, cookie)
}
}
state.Pop(1)
}
}
state.Pop(1)
// Check if session was modified
state.GetGlobal("__session_modified")
if state.IsBoolean(-1) && state.ToBoolean(-1) {
response.SessionModified = true
}
state.Pop(1)
// Clean up
state.Pop(2) // Pop response table and __http_responses
return response, true
}
// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
// Set status code
ctx.SetStatusCode(httpResp.Status)
// Set headers
for name, value := range httpResp.Headers {
ctx.Response.Header.Set(name, value)
}
// Set cookies
for _, cookie := range httpResp.Cookies {
ctx.Response.Header.SetCookie(cookie)
}
// Process the body based on its type
if httpResp.Body == nil {
return
}
// Set body based on type
switch body := httpResp.Body.(type) {
case string:
ctx.SetBodyString(body)
case []byte:
ctx.SetBody(body)
case map[string]any, []any, []float64, []string, []int:
// Marshal JSON using a buffer from the pool
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
if err := json.NewEncoder(buf).Encode(body); err == nil {
// Set content type if not already set
if len(ctx.Response.Header.ContentType()) == 0 {
ctx.Response.Header.SetContentType("application/json")
}
ctx.SetBody(buf.Bytes())
} else {
// Fallback
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
default:
// Default to string representation
ctx.SetBodyString(fmt.Sprintf("%v", body))
}
}
// extractCookie grabs cookies from the Lua state
func extractCookie(state *luajit.State) *fasthttp.Cookie {
cookie := fasthttp.AcquireCookie()
// Get name
state.GetField(-1, "name")
if !state.IsString(-1) {
state.Pop(1)
fasthttp.ReleaseCookie(cookie)
return nil // Name is required
}
cookie.SetKey(state.ToString(-1))
state.Pop(1)
// Get value
state.GetField(-1, "value")
if state.IsString(-1) {
cookie.SetValue(state.ToString(-1))
}
state.Pop(1)
// Get path
state.GetField(-1, "path")
if state.IsString(-1) {
cookie.SetPath(state.ToString(-1))
} else {
cookie.SetPath("/") // Default path
}
state.Pop(1)
// Get domain
state.GetField(-1, "domain")
if state.IsString(-1) {
cookie.SetDomain(state.ToString(-1))
}
state.Pop(1)
// Get expires
state.GetField(-1, "expires")
if state.IsNumber(-1) {
expiry := int64(state.ToNumber(-1))
cookie.SetExpire(time.Unix(expiry, 0))
}
state.Pop(1)
// Get max age
state.GetField(-1, "max_age")
if state.IsNumber(-1) {
cookie.SetMaxAge(int(state.ToNumber(-1)))
}
state.Pop(1)
// Get secure
state.GetField(-1, "secure")
if state.IsBoolean(-1) {
cookie.SetSecure(state.ToBoolean(-1))
}
state.Pop(1)
// Get http only
state.GetField(-1, "http_only")
if state.IsBoolean(-1) {
cookie.SetHTTPOnly(state.ToBoolean(-1))
}
state.Pop(1)
return cookie
}
// httpRequest makes an HTTP request and returns the result to Lua
func httpRequest(state *luajit.State) int {
// Get method (required)
if !state.IsString(1) {
state.PushString("http.client.request: method must be a string")
return -1
}
method := strings.ToUpper(state.ToString(1))
// Get URL (required)
if !state.IsString(2) {
state.PushString("http.client.request: url must be a string")
return -1
}
urlStr := state.ToString(2)
// Parse URL to check if it's valid and if it's allowed
parsedURL, err := url.Parse(urlStr)
if err != nil {
state.PushString("Invalid URL: " + err.Error())
return -1
}
// Get client configuration
var config HTTPClientConfig = DefaultHTTPClientConfig
state.GetGlobal("__http_client_config")
if !state.IsNil(-1) && state.IsTable(-1) {
// Extract max timeout
state.GetField(-1, "max_timeout")
if state.IsNumber(-1) {
config.MaxTimeout = time.Duration(state.ToNumber(-1)) * time.Second
}
state.Pop(1)
// Extract default timeout
state.GetField(-1, "default_timeout")
if state.IsNumber(-1) {
config.DefaultTimeout = time.Duration(state.ToNumber(-1)) * time.Second
}
state.Pop(1)
// Extract max response size
state.GetField(-1, "max_response_size")
if state.IsNumber(-1) {
config.MaxResponseSize = int64(state.ToNumber(-1))
}
state.Pop(1)
// Extract allow remote
state.GetField(-1, "allow_remote")
if state.IsBoolean(-1) {
config.AllowRemote = state.ToBoolean(-1)
}
state.Pop(1)
}
state.Pop(1)
// Check if remote connections are allowed
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
state.PushString("Remote connections are not allowed")
return -1
}
// Use bytebufferpool for request and response
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// Set up request
req.Header.SetMethod(method)
req.SetRequestURI(urlStr)
req.Header.Set("User-Agent", "Moonshark/1.0")
// Get body (optional)
if state.GetTop() >= 3 && !state.IsNil(3) {
if state.IsString(3) {
// String body
req.SetBodyString(state.ToString(3))
} else if state.IsTable(3) {
// Table body - convert to JSON
luaTable, err := state.ToTable(3)
if err != nil {
state.PushString("Failed to parse body table: " + err.Error())
return -1
}
// Use bytebufferpool for JSON serialization
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
state.PushString("Failed to convert body to JSON: " + err.Error())
return -1
}
req.SetBody(buf.Bytes())
} else {
state.PushString("Body must be a string or table")
return -1
}
}
// Process options (headers, timeout, etc.)
timeout := config.DefaultTimeout
if state.GetTop() >= 4 && !state.IsNil(4) {
if !state.IsTable(4) {
state.PushString("Options must be a table")
return -1
}
// Process headers
state.GetField(4, "headers")
if state.IsTable(-1) {
// Iterate through headers
state.PushNil() // Start iteration
for state.Next(-2) {
// Stack now has key at -2 and value at -1
if state.IsString(-2) && state.IsString(-1) {
headerName := state.ToString(-2)
headerValue := state.ToString(-1)
req.Header.Set(headerName, headerValue)
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1) // Pop headers table
// Get timeout
state.GetField(4, "timeout")
if state.IsNumber(-1) {
requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second
// Apply max timeout if configured
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
timeout = config.MaxTimeout
} else {
timeout = requestTimeout
}
}
state.Pop(1) // Pop timeout
// Set content type for POST/PUT if body is present and content-type not manually set
if (method == "POST" || method == "PUT") && req.Body() != nil && req.Header.Peek("Content-Type") == nil {
// Check if options specify content type
state.GetField(4, "content_type")
if state.IsString(-1) {
req.Header.Set("Content-Type", state.ToString(-1))
} else {
// Default to JSON if body is a table, otherwise plain text
if state.IsTable(3) {
req.Header.Set("Content-Type", "application/json")
} else {
req.Header.Set("Content-Type", "text/plain")
}
}
state.Pop(1) // Pop content_type
}
// Process query parameters
state.GetField(4, "query")
if state.IsTable(-1) {
// Create URL args
args := req.URI().QueryArgs()
// Iterate through query params
state.PushNil() // Start iteration
for state.Next(-2) {
// Stack now has key at -2 and value at -1
if state.IsString(-2) {
paramName := state.ToString(-2)
// Handle different value types
if state.IsString(-1) {
args.Add(paramName, state.ToString(-1))
} else if state.IsNumber(-1) {
args.Add(paramName, strings.TrimRight(strings.TrimRight(
state.ToString(-1), "0"), "."))
} else if state.IsBoolean(-1) {
if state.ToBoolean(-1) {
args.Add(paramName, "true")
} else {
args.Add(paramName, "false")
}
}
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1) // Pop query table
}
// Create context with timeout
_, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// Execute request
err = defaultFastClient.DoTimeout(req, resp, timeout)
if err != nil {
errStr := "Request failed: " + err.Error()
if errors.Is(err, fasthttp.ErrTimeout) {
errStr = "Request timed out after " + timeout.String()
}
state.PushString(errStr)
return -1
}
// Create response table
state.NewTable()
// Set status code
state.PushNumber(float64(resp.StatusCode()))
state.SetField(-2, "status")
// Set status text
statusText := fasthttp.StatusMessage(resp.StatusCode())
state.PushString(statusText)
state.SetField(-2, "status_text")
// Set body
var respBody []byte
// Apply size limits to response
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
// Make a limited copy
respBody = make([]byte, config.MaxResponseSize)
copy(respBody, resp.Body())
} else {
respBody = resp.Body()
}
state.PushString(string(respBody))
state.SetField(-2, "body")
// Parse body as JSON if content type is application/json
contentType := string(resp.Header.ContentType())
if strings.Contains(contentType, "application/json") {
var jsonData any
if err := json.Unmarshal(respBody, &jsonData); err == nil {
if err := state.PushValue(jsonData); err == nil {
state.SetField(-2, "json")
}
}
}
// Set headers
state.NewTable()
resp.Header.VisitAll(func(key, value []byte) {
state.PushString(string(value))
state.SetField(-2, string(key))
})
state.SetField(-2, "headers")
// Create ok field (true if status code is 2xx)
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
state.SetField(-2, "ok")
return 1
}

View File

@ -1,84 +0,0 @@
package sandbox
import (
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// ModuleFunc returns a map of module functions
type ModuleFunc func() map[string]luajit.GoFunction
// StateInitFunc initializes a module in a Lua state
type StateInitFunc func(*luajit.State) error
// RegisterModule registers a map of functions as a Lua module
func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
// Create a new table for the module
state.NewTable()
// Add each function to the module table
for fname, f := range funcs {
state.PushString(fname)
if err := state.PushGoFunction(f); err != nil {
state.Pop(1) // Pop table
return err
}
state.SetTable(-3)
}
// Register the module globally
state.SetGlobal(name)
return nil
}
// ModuleInitFunc creates a state initializer that registers multiple modules
func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc {
return func(state *luajit.State) error {
for name, moduleFunc := range modules {
if err := RegisterModule(state, name, moduleFunc()); err != nil {
logger.Error("Failed to register module %s: %v", name, err)
return err
}
}
return nil
}
}
// CombineInitFuncs combines multiple state initializer functions into one
func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc {
return func(state *luajit.State) error {
for _, f := range funcs {
if f != nil {
if err := f(state); err != nil {
return err
}
}
}
return nil
}
}
// RegisterLuaCode registers a Lua code snippet in a state
func RegisterLuaCode(state *luajit.State, code string) error {
return state.DoString(code)
}
// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code
func RegisterLuaCodeInitFunc(code string) StateInitFunc {
return func(state *luajit.State) error {
return RegisterLuaCode(state, code)
}
}
// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module
func RegisterLuaModuleInitFunc(name string, code string) StateInitFunc {
return func(state *luajit.State) error {
// Create name = {} global
state.NewTable()
state.SetGlobal(name)
// Then run the module code which will populate it
return state.DoString(code)
}
}

View File

@ -1,371 +0,0 @@
package sandbox
import (
"fmt"
"sync"
"github.com/goccy/go-json"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
luaCtx "Moonshark/core/runner/context"
"Moonshark/core/sessions"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// Global bytecode cache to improve performance
var (
sandboxBytecode []byte
bytecodeOnce sync.Once
)
// precompileSandbox compiles the sandbox.lua code to bytecode once
func precompileSandbox() {
tempState := luajit.New()
if tempState == nil {
logger.Error("Failed to create temporary Lua state for bytecode compilation")
return
}
defer tempState.Close()
defer tempState.Cleanup()
var err error
sandboxBytecode, err = tempState.CompileBytecode(sandboxLua, "sandbox.lua")
if err != nil {
logger.Error("Failed to precompile sandbox.lua: %v", err)
} else {
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(sandboxBytecode))
}
}
// Sandbox provides a secure execution environment for Lua scripts
type Sandbox struct {
modules map[string]any // Custom modules for environment
debug bool // Enable debug output
mu sync.RWMutex // Protects modules
initializers *ModuleInitializers // Module initializers
}
// NewSandbox creates a new sandbox environment
func NewSandbox() *Sandbox {
return &Sandbox{
modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity
debug: false,
initializers: DefaultInitializers(),
}
}
// EnableDebug turns on debug logging
func (s *Sandbox) EnableDebug() {
s.debug = true
}
// debugLog logs a message if debug mode is enabled
func (s *Sandbox) debugLog(format string, args ...interface{}) {
if s.debug {
logger.Debug("Sandbox "+format, args...)
}
}
// debugLogCont logs a continuation message if debug mode is enabled
func (s *Sandbox) debugLogCont(format string, args ...interface{}) {
if s.debug {
logger.DebugCont(format, args...)
}
}
// AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) {
s.mu.Lock()
defer s.mu.Unlock()
s.modules[name] = module
s.debugLog("Added module: %s", name)
}
// Setup initializes the sandbox in a Lua state
func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
verbose := stateIndex == 0
if verbose {
s.debugLog("Setting up sandbox...")
}
// Initialize modules with the embedded sandbox code
if err := InitializeAll(state, s.initializers); err != nil {
if verbose {
s.debugLog("Failed to initialize sandbox: %v", err)
}
return err
}
// Register custom modules in the global environment
s.mu.RLock()
for name, module := range s.modules {
if verbose {
s.debugLog("Registering module: %s", name)
}
if err := state.PushValue(module); err != nil {
s.mu.RUnlock()
if verbose {
s.debugLog("Failed to register module %s: %v", name, err)
}
return err
}
state.SetGlobal(name)
}
s.mu.RUnlock()
if verbose {
s.debugLogCont("Sandbox setup complete")
}
return nil
}
// Execute runs bytecode in the sandbox
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
// Create a temporary context if we only have a map
if ctx != nil {
tempCtx := &luaCtx.Context{
Values: ctx,
}
return s.OptimizedExecute(state, bytecode, tempCtx)
}
// Just pass nil through if we have no context
return s.OptimizedExecute(state, bytecode, nil)
}
// OptimizedExecute runs bytecode with a fasthttp context if available
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *luaCtx.Context) (any, error) {
// Use a buffer from the pool for any string operations
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
// Load bytecode
if err := state.LoadBytecode(bytecode, "script"); err != nil {
s.debugLog("Failed to load bytecode: %v", err)
return nil, fmt.Errorf("failed to load script: %w", err)
}
// Prepare context values
var ctxValues map[string]any
if ctx != nil {
ctxValues = ctx.Values
} else {
ctxValues = nil
}
// Initialize session tracking in Lua
if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil {
s.debugLog("Failed to initialize session data: %v", err)
}
// Load session data if available
if ctx != nil && ctx.Session != nil {
// Set session ID in Lua
sessionIDCode := fmt.Sprintf("__session_id = %q", ctx.Session.ID)
if err := state.DoString(sessionIDCode); err != nil {
s.debugLog("Failed to set session ID: %v", err)
}
// Get session data and populate Lua table
state.GetGlobal("__session_data")
if state.IsTable(-1) {
sessionData := ctx.Session.GetAll()
for k, v := range sessionData {
state.PushString(k)
if err := state.PushValue(v); err != nil {
s.debugLog("Failed to push session value %s: %v", k, err)
continue
}
state.SetTable(-3)
}
}
state.Pop(1) // Pop __session_data
}
// Prepare context table
if ctxValues != nil {
state.CreateTable(0, len(ctxValues))
for k, v := range ctxValues {
state.PushString(k)
if err := state.PushValue(v); err != nil {
state.Pop(2) // Pop key and table
s.debugLog("Failed to push context value %s: %v", k, err)
return nil, fmt.Errorf("failed to prepare context: %w", err)
}
state.SetTable(-3)
}
} else {
state.PushNil() // No context
}
// Get execution function
state.GetGlobal("__execute_script")
if !state.IsFunction(-1) {
state.Pop(2) // Pop context and non-function
s.debugLog("__execute_script is not a function")
return nil, fmt.Errorf("sandbox execution function not found")
}
// Stack setup for call: __execute_script, bytecode function, context
state.PushCopy(-3) // bytecode function (copy from -3)
state.PushCopy(-3) // context (copy from -3)
// Clean up duplicate references
state.Remove(-5) // Remove original bytecode function
state.Remove(-4) // Remove original context
// Call with 2 args (function, context), 1 result
if err := state.Call(2, 1); err != nil {
s.debugLog("Execution failed: %v", err)
return nil, fmt.Errorf("script execution failed: %w", err)
}
// Get result
result, err := state.ToValue(-1)
state.Pop(1) // Pop result
// Extract session data if it was modified
if ctx != nil && ctx.Session != nil {
// Check if session was modified
state.GetGlobal("__session_modified")
if state.IsBoolean(-1) && state.ToBoolean(-1) {
ctx.SessionModified = true
// Extract session data
state.GetGlobal("__session_data")
if state.IsTable(-1) {
// Clear existing data and extract new data from Lua
sessionData := make(map[string]any)
// Extract new session data
state.PushNil() // Start iteration
for state.Next(-2) {
// Stack now has key at -2 and value at -1
if state.IsString(-2) {
key := state.ToString(-2)
value, err := state.ToValue(-1)
if err == nil {
sessionData[key] = value
}
}
state.Pop(1) // Pop value, leave key for next iteration
}
// Update session with the new data
for k, v := range sessionData {
if err := ctx.Session.Set(k, v); err != nil {
s.debugLog("Failed to set session value %s: %v", k, err)
}
}
}
state.Pop(1) // Pop __session_data
}
state.Pop(1) // Pop __session_modified
}
// Check for HTTP response
httpResponse, hasResponse := GetHTTPResponse(state)
if hasResponse {
// Add the script result as the response body
httpResponse.Body = result
// Mark session as modified if needed
if ctx != nil && ctx.SessionModified {
httpResponse.SessionModified = true
}
// If we have a fasthttp context, apply the response directly
if ctx != nil && ctx.RequestCtx != nil {
// If session was modified, save it
if ctx.SessionModified && ctx.Session != nil {
// Save session and set cookie if needed
sessions.GlobalSessionManager.SaveSession(ctx.Session)
// Add session cookie to the response
cookieOpts := sessions.GlobalSessionManager.CookieOptions()
cookie := fasthttp.AcquireCookie()
cookie.SetKey(cookieOpts["name"].(string))
cookie.SetValue(ctx.Session.ID)
cookie.SetPath(cookieOpts["path"].(string))
if domain, ok := cookieOpts["domain"].(string); ok && domain != "" {
cookie.SetDomain(domain)
}
if maxAge, ok := cookieOpts["max_age"].(int); ok && maxAge > 0 {
cookie.SetMaxAge(maxAge)
}
cookie.SetSecure(cookieOpts["secure"].(bool))
cookie.SetHTTPOnly(cookieOpts["http_only"].(bool))
// Add to response cookies
httpResponse.Cookies = append(httpResponse.Cookies, cookie)
}
ApplyHTTPResponse(httpResponse, ctx.RequestCtx)
ReleaseResponse(httpResponse)
return nil, nil // No need to return response object
}
return httpResponse, nil
}
// If we have a fasthttp context and the result needs to be written directly
if ctx != nil && ctx.RequestCtx != nil && (result != nil) {
// For direct HTTP responses
switch r := result.(type) {
case string:
ctx.RequestCtx.SetBodyString(r)
case []byte:
ctx.RequestCtx.SetBody(r)
case map[string]any, []any:
// JSON response
ctx.RequestCtx.Response.Header.SetContentType("application/json")
if err := json.NewEncoder(buf).Encode(r); err == nil {
ctx.RequestCtx.SetBody(buf.Bytes())
} else {
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
}
default:
// Default string conversion
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
}
// Handle session if modified
if ctx.SessionModified && ctx.Session != nil {
// Save session
sessions.GlobalSessionManager.SaveSession(ctx.Session)
// Add session cookie
cookieOpts := sessions.GlobalSessionManager.CookieOptions()
cookie := fasthttp.AcquireCookie()
cookie.SetKey(cookieOpts["name"].(string))
cookie.SetValue(ctx.Session.ID)
cookie.SetPath(cookieOpts["path"].(string))
if domain, ok := cookieOpts["domain"].(string); ok && domain != "" {
cookie.SetDomain(domain)
}
if maxAge, ok := cookieOpts["max_age"].(int); ok && maxAge > 0 {
cookie.SetMaxAge(maxAge)
}
cookie.SetSecure(cookieOpts["secure"].(bool))
cookie.SetHTTPOnly(cookieOpts["http_only"].(bool))
// Add to response
ctx.RequestCtx.Response.Header.SetCookie(cookie)
}
return nil, nil
}
return result, err
}

View File

@ -1,58 +0,0 @@
package sandbox
import (
"crypto/rand"
"encoding/base64"
"Moonshark/core/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// UtilModuleInitFunc returns an initializer for the util module
func UtilModuleInitFunc() func(*luajit.State) error {
return func(state *luajit.State) error {
return RegisterModule(state, "util", UtilModuleFunctions())
}
}
// UtilModuleFunctions returns all functions for the util module
func UtilModuleFunctions() map[string]luajit.GoFunction {
return map[string]luajit.GoFunction{
"generate_token": GenerateToken,
}
}
// GenerateToken creates a cryptographically secure random token
func GenerateToken(s *luajit.State) int {
// Get the length from the Lua arguments (default to 32)
length := 32
if s.GetTop() >= 1 && s.IsNumber(1) {
length = int(s.ToNumber(1))
}
// Enforce minimum length for security
if length < 16 {
length = 16
}
// Generate secure random bytes
tokenBytes := make([]byte, length)
if _, err := rand.Read(tokenBytes); err != nil {
s.PushString("")
logger.Error("Failed to generate secure token: %v", err)
return 1 // Return empty string on error
}
// Encode as base64
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
// Trim to requested length (base64 might be longer)
if len(token) > length {
token = token[:length]
}
// Push the token to the Lua stack
s.PushString(token)
return 1 // One return value
}

View File

@ -9,6 +9,7 @@ import (
"github.com/VictoriaMetrics/fastcache" "github.com/VictoriaMetrics/fastcache"
"github.com/goccy/go-json" "github.com/goccy/go-json"
"github.com/valyala/fasthttp"
) )
const ( const (
@ -75,7 +76,7 @@ func (sm *SessionManager) GetSession(id string) *Session {
// Store back with updated timestamp // Store back with updated timestamp
updatedData, _ := json.Marshal(session) updatedData, _ := json.Marshal(session)
sm.cache.Set([]byte(id), updatedData) // Use updatedData, not data sm.cache.Set([]byte(id), updatedData)
return session return session
} }
@ -141,5 +142,39 @@ func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, ht
sm.cookieMaxAge = maxAge sm.cookieMaxAge = maxAge
} }
// GetSessionFromRequest extracts the session from a request context
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
cookie := ctx.Request.Header.Cookie(sm.cookieName)
if len(cookie) == 0 {
// No session cookie, create a new session
return sm.CreateSession()
}
// Session cookie exists, get the session
return sm.GetSession(string(cookie))
}
// SaveSessionToResponse adds the session cookie to an HTTP response
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
sm.mu.RLock()
cookie.SetKey(sm.cookieName)
cookie.SetValue(session.ID)
cookie.SetPath(sm.cookiePath)
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
cookie.SetMaxAge(sm.cookieMaxAge)
if sm.cookieDomain != "" {
cookie.SetDomain(sm.cookieDomain)
}
cookie.SetSecure(sm.cookieSecure)
sm.mu.RUnlock()
ctx.Response.Header.SetCookie(cookie)
}
// GlobalSessionManager is the default session manager instance // GlobalSessionManager is the default session manager instance
var GlobalSessionManager = NewSessionManager() var GlobalSessionManager = NewSessionManager()