Compare commits
2 Commits
5ba0a0abd9
...
5ebcd97662
Author | SHA1 | Date | |
---|---|---|---|
5ebcd97662 | |||
f4b1e5fad7 |
|
@ -184,8 +184,6 @@ func (s *Moonshark) initRunner() error {
|
|||
runnerOpts := []runner.RunnerOption{
|
||||
runner.WithPoolSize(s.Config.Runner.PoolSize),
|
||||
runner.WithLibDirs(s.Config.Dirs.Libs...),
|
||||
runner.WithSessionManager(sessionManager),
|
||||
http.WithCSRFProtection(),
|
||||
}
|
||||
|
||||
// Add debug option conditionally
|
||||
|
|
|
@ -5,14 +5,16 @@ import (
|
|||
"Moonshark/core/utils"
|
||||
"Moonshark/core/utils/logger"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Error for CSRF validation failure
|
||||
var ErrCSRFValidationFailed = errors.New("CSRF token validation failed")
|
||||
|
||||
// ValidateCSRFToken checks if the CSRF token is valid for a request
|
||||
func ValidateCSRFToken(state *luajit.State, ctx *runner.Context) bool {
|
||||
func ValidateCSRFToken(ctx *runner.Context) bool {
|
||||
// Only validate for form submissions
|
||||
method, ok := ctx.Get("method").(string)
|
||||
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
|
||||
|
@ -33,87 +35,23 @@ func ValidateCSRFToken(state *luajit.State, ctx *runner.Context) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// Get session token
|
||||
state.GetGlobal("session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
logger.Warning("CSRF validation failed: session module not available")
|
||||
// Get token from session
|
||||
sessionData := ctx.SessionData
|
||||
if sessionData == nil {
|
||||
logger.Warning("CSRF validation failed: no session data")
|
||||
return false
|
||||
}
|
||||
|
||||
state.GetField(-1, "get")
|
||||
if !state.IsFunction(-1) {
|
||||
state.Pop(2)
|
||||
logger.Warning("CSRF validation failed: session.get not available")
|
||||
return false
|
||||
}
|
||||
|
||||
state.PushCopy(-1) // Duplicate function
|
||||
state.PushString("_csrf_token")
|
||||
|
||||
if err := state.Call(1, 1); err != nil {
|
||||
state.Pop(3) // Pop error, function and session table
|
||||
logger.Warning("CSRF validation failed: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(3) // Pop nil, function and session table
|
||||
sessionToken, ok := sessionData["_csrf_token"].(string)
|
||||
if !ok || sessionToken == "" {
|
||||
logger.Warning("CSRF validation failed: no token in session")
|
||||
return false
|
||||
}
|
||||
|
||||
sessionToken := state.ToString(-1)
|
||||
state.Pop(3) // Pop token, function and session table
|
||||
|
||||
// Constant-time comparison to prevent timing attacks
|
||||
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
|
||||
}
|
||||
|
||||
// WithCSRFProtection creates a runner option to add CSRF protection
|
||||
func WithCSRFProtection() runner.RunnerOption {
|
||||
return func(r *runner.Runner) {
|
||||
r.AddInitHook(func(state *luajit.State, ctx *runner.Context) error {
|
||||
// Get request method
|
||||
method, ok := ctx.Get("method").(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only validate for form submissions
|
||||
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for form data
|
||||
form, ok := ctx.Get("form").(map[string]any)
|
||||
if !ok || form == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate CSRF token
|
||||
if !ValidateCSRFToken(state, ctx) {
|
||||
return ErrCSRFValidationFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Error for CSRF validation failure
|
||||
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
|
||||
|
||||
// CSRFError represents a CSRF validation error
|
||||
type CSRFError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *CSRFError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// HandleCSRFError handles a CSRF validation error
|
||||
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
||||
method := string(ctx.Method())
|
||||
|
@ -128,3 +66,39 @@ func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig
|
|||
errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg)
|
||||
ctx.SetBody([]byte(errorHTML))
|
||||
}
|
||||
|
||||
// GenerateCSRFToken creates a new CSRF token and stores it in the session
|
||||
func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) {
|
||||
if length < 16 {
|
||||
length = 16 // Minimum token length for security
|
||||
}
|
||||
|
||||
// Create secure random token
|
||||
token, err := GenerateSecureToken(length)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Store token in session
|
||||
ctx.SessionData["_csrf_token"] = token
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetCSRFToken retrieves the current CSRF token or generates a new one
|
||||
func GetCSRFToken(ctx *runner.Context) (string, error) {
|
||||
// Check if token already exists in session
|
||||
if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Generate new token
|
||||
return GenerateCSRFToken(ctx, 32)
|
||||
}
|
||||
|
||||
// CSRFMiddleware validates CSRF tokens for state-changing requests
|
||||
func CSRFMiddleware(ctx *runner.Context) error {
|
||||
if !ValidateCSRFToken(ctx) {
|
||||
return ErrCSRFValidationFailed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,116 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Maximum form parse size (16MB)
|
||||
const maxFormSize = 16 << 20
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrFormSizeTooLarge = errors.New("form size too large")
|
||||
ErrInvalidFormType = errors.New("invalid form content type")
|
||||
)
|
||||
|
||||
// ParseForm parses a POST request body into a map of values
|
||||
// Supports both application/x-www-form-urlencoded and multipart/form-data content types
|
||||
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||
// Only handle POST, PUT, PATCH
|
||||
method := string(ctx.Method())
|
||||
if method != "POST" && method != "PUT" && method != "PATCH" {
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
// Check content type
|
||||
contentType := string(ctx.Request.Header.ContentType())
|
||||
if contentType == "" {
|
||||
return make(map[string]any), nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
// Check for content length to prevent DOS
|
||||
if len(ctx.Request.Body()) > maxFormSize {
|
||||
return nil, ErrFormSizeTooLarge
|
||||
}
|
||||
|
||||
// Handle by content type
|
||||
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
|
||||
return parseURLEncodedForm(ctx)
|
||||
} else if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return parseMultipartForm(ctx)
|
||||
}
|
||||
|
||||
// Unrecognized content type
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseURLEncodedForm handles application/x-www-form-urlencoded forms
|
||||
func parseURLEncodedForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
|
||||
// Process form values directly from PostArgs()
|
||||
ctx.PostArgs().VisitAll(func(key, value []byte) {
|
||||
keyStr := string(key)
|
||||
valStr := string(value)
|
||||
|
||||
// Check if we already have this key
|
||||
if existing, ok := result[keyStr]; ok {
|
||||
// If it's already a slice, append
|
||||
if existingSlice, ok := existing.([]string); ok {
|
||||
result[keyStr] = append(existingSlice, valStr)
|
||||
} else {
|
||||
// Convert to slice and append
|
||||
result[keyStr] = []string{existing.(string), valStr}
|
||||
}
|
||||
} else {
|
||||
// New key
|
||||
result[keyStr] = valStr
|
||||
}
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseMultipartForm handles multipart/form-data forms
|
||||
func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
|
||||
// Parse multipart form
|
||||
form, err := ctx.MultipartForm()
|
||||
if err != nil {
|
||||
if err == multipart.ErrMessageTooLarge || strings.Contains(err.Error(), "too large") {
|
||||
return nil, ErrFormSizeTooLarge
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process form values
|
||||
for key, values := range form.Value {
|
||||
if len(values) == 1 {
|
||||
// Single value
|
||||
result[key] = values[0]
|
||||
} else if len(values) > 1 {
|
||||
// Multiple values - store as string slice
|
||||
strValues := make([]string, len(values))
|
||||
copy(strValues, values)
|
||||
result[key] = strValues
|
||||
}
|
||||
}
|
||||
|
||||
// We don't handle file uploads here - could be extended in the future
|
||||
// if needed to support file uploads to Lua
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Usage:
|
||||
// After parsing the form with ParseForm, you can add it to the context with:
|
||||
// ctx.Set("form", formData)
|
||||
//
|
||||
// This makes the form data accessible in Lua as ctx.form.field_name
|
|
@ -1,43 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"Moonshark/core/utils/logger"
|
||||
)
|
||||
|
||||
// StatusColors for different status code ranges
|
||||
const (
|
||||
colorGreen = "\033[32m" // 2xx - Success
|
||||
colorCyan = "\033[36m" // 3xx - Redirection
|
||||
colorYellow = "\033[33m" // 4xx - Client Errors
|
||||
colorRed = "\033[31m" // 5xx - Server Errors
|
||||
colorReset = "\033[0m" // Reset color
|
||||
colorGray = "\033[90m"
|
||||
)
|
||||
|
||||
// LogRequest logs an HTTP request with custom formatting
|
||||
func LogRequest(statusCode int, method, path string, duration time.Duration) {
|
||||
statusColor := getStatusColor(statusCode)
|
||||
|
||||
// Use the logger's raw message writer to bypass the standard format
|
||||
logger.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s",
|
||||
colorGray, time.Now().Format(logger.TimeFormat()), colorReset,
|
||||
statusColor, statusCode, method, colorReset, path, colorGray, duration, colorReset)
|
||||
}
|
||||
|
||||
// getStatusColor returns the ANSI color code for a status code
|
||||
func getStatusColor(code int) string {
|
||||
switch {
|
||||
case code >= 200 && code < 300:
|
||||
return colorGreen
|
||||
case code >= 300 && code < 400:
|
||||
return colorCyan
|
||||
case code >= 400 && code < 500:
|
||||
return colorYellow
|
||||
case code >= 500:
|
||||
return colorRed
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// QueryToLua converts HTTP query parameters to a map that can be used with LuaJIT.
|
||||
// Single value parameters are stored as strings.
|
||||
// Multi-value parameters are converted to []any arrays.
|
||||
func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
|
||||
result := make(map[string]any)
|
||||
|
||||
// Use a map to track keys that have multiple values
|
||||
multiValueKeys := make(map[string]bool)
|
||||
|
||||
// Process all query args
|
||||
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
keyStr := string(key)
|
||||
valStr := string(value)
|
||||
|
||||
if _, exists := result[keyStr]; exists {
|
||||
// This key already exists, convert to array if not already
|
||||
if !multiValueKeys[keyStr] {
|
||||
// First duplicate, convert existing value to array
|
||||
multiValueKeys[keyStr] = true
|
||||
result[keyStr] = []any{result[keyStr], valStr}
|
||||
} else {
|
||||
// Already an array, append
|
||||
result[keyStr] = append(result[keyStr].([]any), valStr)
|
||||
}
|
||||
} else {
|
||||
// New key
|
||||
result[keyStr] = valStr
|
||||
}
|
||||
})
|
||||
|
||||
// If we don't have any query parameters, return empty map
|
||||
if len(result) == 0 {
|
||||
return make(map[string]any)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
|
@ -2,19 +2,17 @@ package http
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"Moonshark/core/metadata"
|
||||
"Moonshark/core/routers"
|
||||
"Moonshark/core/runner"
|
||||
"Moonshark/core/runner/sandbox"
|
||||
"Moonshark/core/sessions"
|
||||
"Moonshark/core/utils"
|
||||
"Moonshark/core/utils/config"
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
|
@ -27,12 +25,14 @@ type Server struct {
|
|||
loggingEnabled bool
|
||||
debugMode bool
|
||||
config *config.Config
|
||||
sessionManager *sessions.SessionManager
|
||||
errorConfig utils.ErrorPageConfig
|
||||
}
|
||||
|
||||
// New creates a new HTTP server with optimized connection settings
|
||||
func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runner *runner.Runner,
|
||||
loggingEnabled bool, debugMode bool, overrideDir string, config *config.Config) *Server {
|
||||
func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter,
|
||||
runner *runner.Runner, loggingEnabled bool, debugMode bool,
|
||||
overrideDir string, config *config.Config) *Server {
|
||||
|
||||
server := &Server{
|
||||
luaRouter: luaRouter,
|
||||
|
@ -41,6 +41,7 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
|
|||
loggingEnabled: loggingEnabled,
|
||||
debugMode: debugMode,
|
||||
config: config,
|
||||
sessionManager: sessions.GlobalSessionManager,
|
||||
errorConfig: utils.ErrorPageConfig{
|
||||
OverrideDir: overrideDir,
|
||||
DebugMode: debugMode,
|
||||
|
@ -53,7 +54,7 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
|
|||
Name: "Moonshark/" + metadata.Version,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
MaxRequestBodySize: 16 << 20, // 16MB - consistent with Forms.go
|
||||
MaxRequestBodySize: 16 << 20, // 16MB
|
||||
DisableKeepalive: false,
|
||||
TCPKeepalive: true,
|
||||
TCPKeepalivePeriod: 60 * time.Second,
|
||||
|
@ -97,7 +98,7 @@ func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
|
|||
// Process the request
|
||||
s.processRequest(ctx)
|
||||
|
||||
// Log the request with our custom format
|
||||
// Log the request
|
||||
if s.loggingEnabled {
|
||||
duration := time.Since(start)
|
||||
LogRequest(ctx.Response.StatusCode(), method, path, duration)
|
||||
|
@ -151,47 +152,25 @@ func (s *Server) processRequest(ctx *fasthttp.RequestCtx) {
|
|||
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
|
||||
}
|
||||
|
||||
// HandleMethodNotAllowed responds with a 405 Method Not Allowed error
|
||||
func HandleMethodNotAllowed(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
||||
path := string(ctx.Path())
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
|
||||
ctx.SetBody([]byte(utils.MethodNotAllowedPage(errorConfig, path)))
|
||||
}
|
||||
|
||||
// handleLuaRoute executes a Lua route
|
||||
func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params) {
|
||||
luaCtx := runner.NewHTTPContext(ctx) // Use NewHTTPContext instead of NewContext
|
||||
// Create context for Lua execution
|
||||
luaCtx := runner.NewHTTPContext(ctx)
|
||||
defer luaCtx.Release()
|
||||
|
||||
method := string(ctx.Method())
|
||||
path := string(ctx.Path())
|
||||
host := string(ctx.Host())
|
||||
|
||||
// Set up context
|
||||
// Set up additional context values
|
||||
luaCtx.Set("method", method)
|
||||
luaCtx.Set("path", path)
|
||||
luaCtx.Set("host", host)
|
||||
|
||||
// Headers
|
||||
headerMap := make(map[string]any)
|
||||
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
headerMap[string(key)] = string(value)
|
||||
})
|
||||
luaCtx.Set("headers", headerMap)
|
||||
|
||||
// Cookies
|
||||
cookieMap := make(map[string]any)
|
||||
ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||
cookieMap[string(key)] = string(value)
|
||||
})
|
||||
if len(cookieMap) > 0 {
|
||||
luaCtx.Set("cookies", cookieMap)
|
||||
luaCtx.Set("_request_cookies", cookieMap) // For backward compatibility
|
||||
} else {
|
||||
luaCtx.Set("cookies", make(map[string]any))
|
||||
luaCtx.Set("_request_cookies", make(map[string]any))
|
||||
}
|
||||
// Initialize session
|
||||
session := s.sessionManager.GetSessionFromRequest(ctx)
|
||||
luaCtx.SessionID = session.ID
|
||||
luaCtx.SessionData = session.GetAll()
|
||||
|
||||
// URL parameters
|
||||
if params.Count > 0 {
|
||||
|
@ -204,11 +183,7 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||
luaCtx.Set("params", make(map[string]any))
|
||||
}
|
||||
|
||||
// Query parameters
|
||||
queryMap := QueryToLua(ctx)
|
||||
luaCtx.Set("query", queryMap)
|
||||
|
||||
// Form data
|
||||
// Parse form data for POST/PUT/PATCH requests
|
||||
if method == "POST" || method == "PUT" || method == "PATCH" {
|
||||
formData, err := ParseForm(ctx)
|
||||
if err == nil && len(formData) > 0 {
|
||||
|
@ -223,19 +198,26 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||
luaCtx.Set("form", make(map[string]any))
|
||||
}
|
||||
|
||||
// Execute Lua script
|
||||
result, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
|
||||
// CSRF middleware for state-changing requests
|
||||
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
||||
if !ValidateCSRFToken(luaCtx) {
|
||||
HandleCSRFError(ctx, s.errorConfig)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Special handling for CSRF error
|
||||
// Execute Lua script
|
||||
response, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
|
||||
if err != nil {
|
||||
if csrfErr, ok := err.(*CSRFError); ok {
|
||||
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
|
||||
logger.Error("Error executing Lua route: %v", err)
|
||||
|
||||
// Special handling for specific errors
|
||||
if errors.Is(err, ErrCSRFValidationFailed) {
|
||||
HandleCSRFError(ctx, s.errorConfig)
|
||||
return
|
||||
}
|
||||
|
||||
// Normal error handling
|
||||
logger.Error("Error executing Lua route: %v", err)
|
||||
// General error handling
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
errorHTML := utils.InternalErrorPage(s.errorConfig, path, err.Error())
|
||||
|
@ -243,104 +225,21 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||
return
|
||||
}
|
||||
|
||||
// If we got a non-nil result, write it to the response
|
||||
if result != nil {
|
||||
writeResponse(ctx, result)
|
||||
// Save session if modified
|
||||
if response.SessionModified {
|
||||
// Update session data
|
||||
for k, v := range response.SessionData {
|
||||
session.Set(k, v)
|
||||
}
|
||||
s.sessionManager.SaveSession(session)
|
||||
s.sessionManager.ApplySessionCookie(ctx, session)
|
||||
}
|
||||
|
||||
// Content types for responses
|
||||
const (
|
||||
contentTypeJSON = "application/json"
|
||||
contentTypePlain = "text/plain"
|
||||
)
|
||||
// Apply response to HTTP context
|
||||
runner.ApplyResponse(response, ctx)
|
||||
|
||||
// writeResponse writes the Lua result to the HTTP response
|
||||
func writeResponse(ctx *fasthttp.RequestCtx, result any) {
|
||||
if result == nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
// First check the raw type of the result for strong type identification
|
||||
// Sometimes type assertions don't work as expected with interface values
|
||||
resultType := fmt.Sprintf("%T", result)
|
||||
|
||||
// Strong check for HTTP response
|
||||
if strings.Contains(resultType, "HTTPResponse") || strings.Contains(resultType, "sandbox.HTTPResponse") {
|
||||
httpResp, ok := result.(*sandbox.HTTPResponse)
|
||||
if ok {
|
||||
defer sandbox.ReleaseResponse(httpResp)
|
||||
|
||||
// Set response headers
|
||||
for name, value := range httpResp.Headers {
|
||||
ctx.Response.Header.Set(name, value)
|
||||
}
|
||||
|
||||
// Set cookies
|
||||
for _, cookie := range httpResp.Cookies {
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// Set status code
|
||||
ctx.SetStatusCode(httpResp.Status)
|
||||
|
||||
// Process the body based on its type
|
||||
if httpResp.Body == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Continue with the body only
|
||||
result = httpResp.Body
|
||||
} else {
|
||||
// We identified it as HTTPResponse but couldn't convert it
|
||||
// This is a programming error
|
||||
logger.Error("Found HTTPResponse type but failed to convert: %v", resultType)
|
||||
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if it's a map (table) or array - return as JSON
|
||||
isJSON := false
|
||||
switch result.(type) {
|
||||
case map[string]any, []any, []float64, []string, []int:
|
||||
isJSON = true
|
||||
}
|
||||
|
||||
if isJSON {
|
||||
setContentTypeIfMissing(ctx, contentTypeJSON)
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
logger.Error("Failed to marshal response: %v", err)
|
||||
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
ctx.SetBody(data)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle string and byte slice cases directly
|
||||
switch r := result.(type) {
|
||||
case string:
|
||||
setContentTypeIfMissing(ctx, contentTypePlain)
|
||||
ctx.SetBodyString(r)
|
||||
return
|
||||
case []byte:
|
||||
setContentTypeIfMissing(ctx, contentTypePlain)
|
||||
ctx.SetBody(r)
|
||||
return
|
||||
}
|
||||
|
||||
// If we reach here, it's an unexpected type - convert to string as a last resort
|
||||
setContentTypeIfMissing(ctx, contentTypePlain)
|
||||
ctx.SetBodyString(fmt.Sprintf("%v", result))
|
||||
}
|
||||
|
||||
func setContentTypeIfMissing(ctx *fasthttp.RequestCtx, contentType string) {
|
||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||
ctx.SetContentType(contentType)
|
||||
}
|
||||
// Release the response when done
|
||||
runner.ReleaseResponse(response)
|
||||
}
|
||||
|
||||
// handleDebugStats displays debug statistics
|
||||
|
@ -350,12 +249,14 @@ func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) {
|
|||
|
||||
// Add component stats
|
||||
routeCount, bytecodeBytes := s.luaRouter.GetRouteStats()
|
||||
moduleCount := s.luaRunner.GetModuleCount()
|
||||
//stateCount := s.luaRunner.GetStateCount()
|
||||
//activeStates := s.luaRunner.GetActiveStateCount()
|
||||
|
||||
stats.Components = utils.ComponentStats{
|
||||
RouteCount: routeCount,
|
||||
BytecodeBytes: bytecodeBytes,
|
||||
ModuleCount: moduleCount,
|
||||
//StatesCount: stateCount,
|
||||
//ActiveStates: activeStates,
|
||||
}
|
||||
|
||||
// Generate HTML page
|
||||
|
|
206
core/http/Utils.go
Normal file
206
core/http/Utils.go
Normal file
|
@ -0,0 +1,206 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// LogRequest logs an HTTP request with its status code and duration
|
||||
func LogRequest(statusCode int, method, path string, duration time.Duration) {
|
||||
var statusColor, resetColor, methodColor string
|
||||
|
||||
// Status code colors
|
||||
if statusCode >= 200 && statusCode < 300 {
|
||||
statusColor = "\u001b[32m" // Green for 2xx
|
||||
} else if statusCode >= 300 && statusCode < 400 {
|
||||
statusColor = "\u001b[36m" // Cyan for 3xx
|
||||
} else if statusCode >= 400 && statusCode < 500 {
|
||||
statusColor = "\u001b[33m" // Yellow for 4xx
|
||||
} else {
|
||||
statusColor = "\u001b[31m" // Red for 5xx and others
|
||||
}
|
||||
|
||||
// Method colors
|
||||
switch method {
|
||||
case "GET":
|
||||
methodColor = "\u001b[32m" // Green
|
||||
case "POST":
|
||||
methodColor = "\u001b[34m" // Blue
|
||||
case "PUT":
|
||||
methodColor = "\u001b[33m" // Yellow
|
||||
case "DELETE":
|
||||
methodColor = "\u001b[31m" // Red
|
||||
default:
|
||||
methodColor = "\u001b[35m" // Magenta for others
|
||||
}
|
||||
|
||||
resetColor = "\u001b[0m"
|
||||
|
||||
// Format duration
|
||||
var durationStr string
|
||||
if duration.Milliseconds() < 1 {
|
||||
durationStr = fmt.Sprintf("%.2fµs", float64(duration.Microseconds()))
|
||||
} else if duration.Milliseconds() < 1000 {
|
||||
durationStr = fmt.Sprintf("%.2fms", float64(duration.Microseconds())/1000)
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.2fs", duration.Seconds())
|
||||
}
|
||||
|
||||
// Log with colors
|
||||
logger.Server("%s%d%s %s%s%s %s %s",
|
||||
statusColor, statusCode, resetColor,
|
||||
methodColor, method, resetColor,
|
||||
path, durationStr)
|
||||
}
|
||||
|
||||
// QueryToLua converts HTTP query args to a Lua-friendly map
|
||||
func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
|
||||
queryMap := make(map[string]any)
|
||||
|
||||
// Visit all query parameters
|
||||
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
// Convert to string
|
||||
k := string(key)
|
||||
v := string(value)
|
||||
|
||||
// Check if this key already exists as an array
|
||||
if existing, ok := queryMap[k]; ok {
|
||||
// If it's already an array, append to it
|
||||
if arr, ok := existing.([]string); ok {
|
||||
queryMap[k] = append(arr, v)
|
||||
} else if str, ok := existing.(string); ok {
|
||||
// Convert existing string to array and append new value
|
||||
queryMap[k] = []string{str, v}
|
||||
}
|
||||
} else {
|
||||
// New key, store as string
|
||||
queryMap[k] = v
|
||||
}
|
||||
})
|
||||
|
||||
return queryMap
|
||||
}
|
||||
|
||||
// ParseForm extracts form data from a request
|
||||
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||
formData := make(map[string]any)
|
||||
|
||||
// Check if multipart form
|
||||
if strings.Contains(string(ctx.Request.Header.ContentType()), "multipart/form-data") {
|
||||
return parseMultipartForm(ctx)
|
||||
}
|
||||
|
||||
// Regular form
|
||||
ctx.PostArgs().VisitAll(func(key, value []byte) {
|
||||
k := string(key)
|
||||
v := string(value)
|
||||
|
||||
// Check if this key already exists
|
||||
if existing, ok := formData[k]; ok {
|
||||
// If it's already an array, append to it
|
||||
if arr, ok := existing.([]string); ok {
|
||||
formData[k] = append(arr, v)
|
||||
} else if str, ok := existing.(string); ok {
|
||||
// Convert existing string to array and append new value
|
||||
formData[k] = []string{str, v}
|
||||
}
|
||||
} else {
|
||||
// New key, store as string
|
||||
formData[k] = v
|
||||
}
|
||||
})
|
||||
|
||||
return formData, nil
|
||||
}
|
||||
|
||||
// parseMultipartForm handles multipart/form-data requests
|
||||
func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||
formData := make(map[string]any)
|
||||
|
||||
// Parse multipart form
|
||||
form, err := ctx.MultipartForm()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process form values
|
||||
for key, values := range form.Value {
|
||||
if len(values) == 1 {
|
||||
formData[key] = values[0]
|
||||
} else if len(values) > 1 {
|
||||
formData[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
// Process files (store file info, not the content)
|
||||
if len(form.File) > 0 {
|
||||
files := make(map[string]any)
|
||||
|
||||
for fieldName, fileHeaders := range form.File {
|
||||
if len(fileHeaders) == 1 {
|
||||
files[fieldName] = fileInfoToMap(fileHeaders[0])
|
||||
} else if len(fileHeaders) > 1 {
|
||||
fileInfos := make([]map[string]any, 0, len(fileHeaders))
|
||||
for _, fh := range fileHeaders {
|
||||
fileInfos = append(fileInfos, fileInfoToMap(fh))
|
||||
}
|
||||
files[fieldName] = fileInfos
|
||||
}
|
||||
}
|
||||
|
||||
formData["_files"] = files
|
||||
}
|
||||
|
||||
return formData, nil
|
||||
}
|
||||
|
||||
// fileInfoToMap converts a FileHeader to a map for Lua
|
||||
func fileInfoToMap(fh *multipart.FileHeader) map[string]any {
|
||||
return map[string]any{
|
||||
"filename": fh.Filename,
|
||||
"size": fh.Size,
|
||||
"mimetype": getMimeType(fh),
|
||||
}
|
||||
}
|
||||
|
||||
// getMimeType gets the mime type from a file header
|
||||
func getMimeType(fh *multipart.FileHeader) string {
|
||||
if fh.Header != nil {
|
||||
contentType := fh.Header.Get("Content-Type")
|
||||
if contentType != "" {
|
||||
return contentType
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to basic type detection from filename
|
||||
if strings.HasSuffix(fh.Filename, ".pdf") {
|
||||
return "application/pdf"
|
||||
} else if strings.HasSuffix(fh.Filename, ".png") {
|
||||
return "image/png"
|
||||
} else if strings.HasSuffix(fh.Filename, ".jpg") || strings.HasSuffix(fh.Filename, ".jpeg") {
|
||||
return "image/jpeg"
|
||||
} else if strings.HasSuffix(fh.Filename, ".gif") {
|
||||
return "image/gif"
|
||||
} else if strings.HasSuffix(fh.Filename, ".svg") {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
// GenerateSecureToken creates a cryptographically secure random token
|
||||
func GenerateSecureToken(length int) (string, error) {
|
||||
b := make([]byte, length)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b)[:length], nil
|
||||
}
|
|
@ -15,6 +15,10 @@ type Context struct {
|
|||
// FastHTTP context if this was created from an HTTP request
|
||||
RequestCtx *fasthttp.RequestCtx
|
||||
|
||||
// Session information
|
||||
SessionID string
|
||||
SessionData map[string]any
|
||||
|
||||
// Buffer for efficient string operations
|
||||
buffer *bytebufferpool.ByteBuffer
|
||||
}
|
||||
|
@ -24,6 +28,7 @@ var contextPool = sync.Pool{
|
|||
New: func() any {
|
||||
return &Context{
|
||||
Values: make(map[string]any, 16),
|
||||
SessionData: make(map[string]any, 8),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -37,6 +42,44 @@ func NewContext() *Context {
|
|||
func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context {
|
||||
ctx := NewContext()
|
||||
ctx.RequestCtx = requestCtx
|
||||
|
||||
// Extract common HTTP values that Lua might need
|
||||
if requestCtx != nil {
|
||||
ctx.Values["_request_method"] = string(requestCtx.Method())
|
||||
ctx.Values["_request_path"] = string(requestCtx.Path())
|
||||
ctx.Values["_request_url"] = string(requestCtx.RequestURI())
|
||||
|
||||
// Extract cookies
|
||||
cookies := make(map[string]any)
|
||||
requestCtx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||
cookies[string(key)] = string(value)
|
||||
})
|
||||
ctx.Values["_request_cookies"] = cookies
|
||||
|
||||
// Extract query params
|
||||
query := make(map[string]any)
|
||||
requestCtx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
query[string(key)] = string(value)
|
||||
})
|
||||
ctx.Values["_request_query"] = query
|
||||
|
||||
// Extract form data if present
|
||||
if requestCtx.IsPost() || requestCtx.IsPut() {
|
||||
form := make(map[string]any)
|
||||
requestCtx.PostArgs().VisitAll(func(key, value []byte) {
|
||||
form[string(key)] = string(value)
|
||||
})
|
||||
ctx.Values["_request_form"] = form
|
||||
}
|
||||
|
||||
// Extract headers
|
||||
headers := make(map[string]any)
|
||||
requestCtx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
headers[string(key)] = string(value)
|
||||
})
|
||||
ctx.Values["_request_headers"] = headers
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
|
@ -47,6 +90,13 @@ func (c *Context) Release() {
|
|||
delete(c.Values, k)
|
||||
}
|
||||
|
||||
for k := range c.SessionData {
|
||||
delete(c.SessionData, k)
|
||||
}
|
||||
|
||||
// Reset session info
|
||||
c.SessionID = ""
|
||||
|
||||
// Reset request context
|
||||
c.RequestCtx = nil
|
||||
|
||||
|
@ -77,13 +127,12 @@ func (c *Context) Get(key string) any {
|
|||
return c.Values[key]
|
||||
}
|
||||
|
||||
// Contains checks if a key exists in the context
|
||||
func (c *Context) Contains(key string) bool {
|
||||
_, exists := c.Values[key]
|
||||
return exists
|
||||
// SetSession sets a session data value
|
||||
func (c *Context) SetSession(key string, value any) {
|
||||
c.SessionData[key] = value
|
||||
}
|
||||
|
||||
// Delete removes a value from the context
|
||||
func (c *Context) Delete(key string) {
|
||||
delete(c.Values, key)
|
||||
// GetSession retrieves a session data value
|
||||
func (c *Context) GetSession(key string) any {
|
||||
return c.SessionData[key]
|
||||
}
|
||||
|
|
|
@ -1,262 +0,0 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"Moonshark/core/runner/sandbox"
|
||||
"Moonshark/core/utils/logger"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// CoreModuleRegistry manages the initialization and reloading of core modules
|
||||
type CoreModuleRegistry struct {
|
||||
modules map[string]sandbox.StateInitFunc // Module initializers
|
||||
initOrder []string // Explicit initialization order
|
||||
dependencies map[string][]string // Module dependencies
|
||||
initializedFlag map[string]bool // Track which modules are initialized
|
||||
mu sync.RWMutex
|
||||
debug bool
|
||||
}
|
||||
|
||||
// NewCoreModuleRegistry creates a new core module registry
|
||||
func NewCoreModuleRegistry() *CoreModuleRegistry {
|
||||
return &CoreModuleRegistry{
|
||||
modules: make(map[string]sandbox.StateInitFunc),
|
||||
initOrder: []string{},
|
||||
dependencies: make(map[string][]string),
|
||||
initializedFlag: make(map[string]bool),
|
||||
debug: false,
|
||||
}
|
||||
}
|
||||
|
||||
// EnableDebug turns on debug logging
|
||||
func (r *CoreModuleRegistry) EnableDebug() {
|
||||
r.debug = true
|
||||
}
|
||||
|
||||
// debugLog prints debug messages if enabled
|
||||
func (r *CoreModuleRegistry) debugLog(format string, args ...interface{}) {
|
||||
if r.debug {
|
||||
logger.Debug("CoreRegistry "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a module to the registry
|
||||
func (r *CoreModuleRegistry) Register(name string, initFunc sandbox.StateInitFunc) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.modules[name] = initFunc
|
||||
|
||||
// Add to initialization order if not already there
|
||||
for _, n := range r.initOrder {
|
||||
if n == name {
|
||||
return // Already registered
|
||||
}
|
||||
}
|
||||
|
||||
r.initOrder = append(r.initOrder, name)
|
||||
r.debugLog("registered module %s", name)
|
||||
}
|
||||
|
||||
// RegisterWithDependencies registers a module with explicit dependencies
|
||||
func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.modules[name] = initFunc
|
||||
r.dependencies[name] = dependencies
|
||||
|
||||
// Add to initialization order if not already there
|
||||
for _, n := range r.initOrder {
|
||||
if n == name {
|
||||
return // Already registered
|
||||
}
|
||||
}
|
||||
|
||||
r.initOrder = append(r.initOrder, name)
|
||||
r.debugLog("registered module %s with dependencies: %v", name, dependencies)
|
||||
}
|
||||
|
||||
// SetInitOrder sets explicit initialization order
|
||||
func (r *CoreModuleRegistry) SetInitOrder(order []string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Create new init order
|
||||
newOrder := make([]string, 0, len(order))
|
||||
|
||||
// First add all known modules that are in the specified order
|
||||
for _, name := range order {
|
||||
if _, exists := r.modules[name]; exists && !contains(newOrder, name) {
|
||||
newOrder = append(newOrder, name)
|
||||
}
|
||||
}
|
||||
|
||||
// Then add any modules not in the specified order
|
||||
for name := range r.modules {
|
||||
if !contains(newOrder, name) {
|
||||
newOrder = append(newOrder, name)
|
||||
}
|
||||
}
|
||||
|
||||
r.initOrder = newOrder
|
||||
r.debugLog("Set initialization order: %v", r.initOrder)
|
||||
}
|
||||
|
||||
// Initialize initializes all registered modules
|
||||
func (r *CoreModuleRegistry) Initialize(state *luajit.State, stateIndex int) error {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
verbose := stateIndex == 0
|
||||
if verbose {
|
||||
r.debugLog("initializing %d modules...", len(r.initOrder))
|
||||
}
|
||||
|
||||
// Clear initialization flags
|
||||
r.initializedFlag = make(map[string]bool)
|
||||
|
||||
// Initialize modules in order, respecting dependencies
|
||||
for _, name := range r.initOrder {
|
||||
if err := r.initializeModule(state, name, []string{}, verbose); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if verbose {
|
||||
r.debugLog("All modules initialized successfully")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initializeModule initializes a module and its dependencies
|
||||
func (r *CoreModuleRegistry) initializeModule(state *luajit.State, name string,
|
||||
initStack []string, verbose bool) error {
|
||||
// Check if already initialized
|
||||
if r.initializedFlag[name] {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for circular dependencies
|
||||
for _, n := range initStack {
|
||||
if n == name {
|
||||
return fmt.Errorf("circular dependency detected: %s -> %s",
|
||||
strings.Join(initStack, " -> "), name)
|
||||
}
|
||||
}
|
||||
|
||||
// Get init function
|
||||
initFunc, ok := r.modules[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("module not found: %s", name)
|
||||
}
|
||||
|
||||
// Initialize dependencies first
|
||||
deps := r.dependencies[name]
|
||||
if len(deps) > 0 {
|
||||
newStack := append(initStack, name)
|
||||
for _, dep := range deps {
|
||||
if err := r.initializeModule(state, dep, newStack, verbose); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := initFunc(state)
|
||||
if err != nil {
|
||||
// Always log failures regardless of verbose setting
|
||||
r.debugLog("Initializing module %s... failure: %v", name, err)
|
||||
return fmt.Errorf("failed to initialize module %s: %w", name, err)
|
||||
}
|
||||
|
||||
r.initializedFlag[name] = true
|
||||
|
||||
if verbose {
|
||||
r.debugLog("Initializing module %s... success", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitializeModule initializes a specific module
|
||||
func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string) error {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Clear initialization flag for this module
|
||||
r.initializedFlag[name] = false
|
||||
|
||||
// Always use verbose logging for explicit module initialization
|
||||
return r.initializeModule(state, name, []string{}, true)
|
||||
}
|
||||
|
||||
// MatchModuleName checks if a file path corresponds to a registered module
|
||||
func (r *CoreModuleRegistry) MatchModuleName(modName string) (string, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Exact match
|
||||
if _, ok := r.modules[modName]; ok {
|
||||
return modName, true
|
||||
}
|
||||
|
||||
// Check if the module name ends with a registered module
|
||||
for name := range r.modules {
|
||||
if strings.HasSuffix(modName, "."+name) {
|
||||
return name, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Global registry instance
|
||||
var GlobalRegistry = NewCoreModuleRegistry()
|
||||
|
||||
// Initialize global registry with core modules
|
||||
func init() {
|
||||
GlobalRegistry.EnableDebug() // Enable debugging by default
|
||||
logger.Debug("[ModuleRegistry] Registering core modules...")
|
||||
|
||||
// Register core modules
|
||||
GlobalRegistry.Register("util", func(state *luajit.State) error {
|
||||
return sandbox.UtilModuleInitFunc()(state)
|
||||
})
|
||||
|
||||
GlobalRegistry.Register("http", func(state *luajit.State) error {
|
||||
return sandbox.HTTPModuleInitFunc()(state)
|
||||
})
|
||||
|
||||
// Set explicit initialization order
|
||||
GlobalRegistry.SetInitOrder([]string{
|
||||
"util", // First: core utilities
|
||||
"http", // Second: HTTP functionality
|
||||
"session", // Third: Session functionality
|
||||
"csrf", // Fourth: CSRF protection
|
||||
})
|
||||
|
||||
logger.Debug("Core modules registered successfully")
|
||||
}
|
||||
|
||||
// RegisterCoreModule registers a core module with the global registry
|
||||
func RegisterCoreModule(name string, initFunc sandbox.StateInitFunc) {
|
||||
GlobalRegistry.Register(name, initFunc)
|
||||
}
|
||||
|
||||
// RegisterCoreModuleWithDependencies registers a module with dependencies
|
||||
func RegisterCoreModuleWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) {
|
||||
GlobalRegistry.RegisterWithDependencies(name, initFunc, dependencies)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func contains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
61
core/runner/Embed.go
Normal file
61
core/runner/Embed.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
//go:embed sandbox.lua
|
||||
var sandboxLuaCode string
|
||||
|
||||
// Global bytecode cache to improve performance
|
||||
var (
|
||||
sandboxBytecode atomic.Pointer[[]byte]
|
||||
bytecodeOnce sync.Once
|
||||
)
|
||||
|
||||
// precompileSandboxCode compiles the sandbox.lua code to bytecode once
|
||||
func precompileSandboxCode() {
|
||||
// Create temporary state for compilation
|
||||
tempState := luajit.New()
|
||||
if tempState == nil {
|
||||
logger.Error("Failed to create temp Lua state for bytecode compilation")
|
||||
return
|
||||
}
|
||||
defer tempState.Close()
|
||||
defer tempState.Cleanup()
|
||||
|
||||
code, err := tempState.CompileBytecode(sandboxLuaCode, "sandbox.lua")
|
||||
if err != nil {
|
||||
logger.Error("Failed to compile sandbox code: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
bytecode := make([]byte, len(code))
|
||||
copy(bytecode, code)
|
||||
sandboxBytecode.Store(&bytecode)
|
||||
|
||||
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code))
|
||||
}
|
||||
|
||||
// loadSandboxIntoState loads the sandbox code into a Lua state
|
||||
func loadSandboxIntoState(state *luajit.State) error {
|
||||
// Initialize bytecode once
|
||||
bytecodeOnce.Do(precompileSandboxCode)
|
||||
|
||||
// Use precompiled bytecode if available
|
||||
bytecode := sandboxBytecode.Load()
|
||||
if bytecode != nil && len(*bytecode) > 0 {
|
||||
logger.Debug("Loading sandbox.lua from precompiled bytecode")
|
||||
return state.LoadAndRunBytecode(*bytecode, "sandbox.lua")
|
||||
}
|
||||
|
||||
// Fallback to direct execution
|
||||
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
|
||||
return state.DoString(sandboxLuaCode)
|
||||
}
|
|
@ -1,12 +1,13 @@
|
|||
package sandbox
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
|
@ -18,27 +19,8 @@ import (
|
|||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// HTTPResponse represents an HTTP response from Lua
|
||||
type HTTPResponse struct {
|
||||
Status int `json:"status"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body any `json:"body"`
|
||||
Cookies []*fasthttp.Cookie `json:"-"`
|
||||
}
|
||||
|
||||
// Response pool to reduce allocations
|
||||
var responsePool = sync.Pool{
|
||||
New: func() any {
|
||||
return &HTTPResponse{
|
||||
Status: 200,
|
||||
Headers: make(map[string]string, 8),
|
||||
Cookies: make([]*fasthttp.Cookie, 0, 4),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Default HTTP client with sensible timeout
|
||||
var defaultFastClient fasthttp.Client = fasthttp.Client{
|
||||
var defaultFastClient = fasthttp.Client{
|
||||
MaxConnsPerHost: 1024,
|
||||
MaxIdleConnDuration: time.Minute,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
|
@ -62,171 +44,38 @@ var DefaultHTTPClientConfig = HTTPClientConfig{
|
|||
AllowRemote: true,
|
||||
}
|
||||
|
||||
// NewHTTPResponse creates a default HTTP response from pool
|
||||
func NewHTTPResponse() *HTTPResponse {
|
||||
return responsePool.Get().(*HTTPResponse)
|
||||
}
|
||||
|
||||
// ReleaseResponse returns the response to the pool
|
||||
func ReleaseResponse(resp *HTTPResponse) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear all values to prevent data leakage
|
||||
resp.Status = 200 // Reset to default
|
||||
|
||||
// Clear headers
|
||||
for k := range resp.Headers {
|
||||
delete(resp.Headers, k)
|
||||
}
|
||||
|
||||
// Clear cookies
|
||||
resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0
|
||||
|
||||
// Clear body
|
||||
resp.Body = nil
|
||||
|
||||
responsePool.Put(resp)
|
||||
}
|
||||
|
||||
// HTTPModuleInitFunc returns an initializer function for the HTTP module
|
||||
func HTTPModuleInitFunc() func(*luajit.State) error {
|
||||
return func(state *luajit.State) error {
|
||||
// Register the native Go function first
|
||||
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||
logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up default HTTP client configuration
|
||||
setupHTTPClientConfig(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// setupHTTPClientConfig configures HTTP client in Lua
|
||||
func setupHTTPClientConfig(state *luajit.State) {
|
||||
state.NewTable()
|
||||
|
||||
state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
|
||||
state.SetField(-2, "max_timeout")
|
||||
|
||||
state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second))
|
||||
state.SetField(-2, "default_timeout")
|
||||
|
||||
state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize))
|
||||
state.SetField(-2, "max_response_size")
|
||||
|
||||
state.PushBoolean(DefaultHTTPClientConfig.AllowRemote)
|
||||
state.SetField(-2, "allow_remote")
|
||||
|
||||
state.SetGlobal("__http_client_config")
|
||||
}
|
||||
|
||||
// GetHTTPResponse extracts the HTTP response from Lua state
|
||||
func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
|
||||
response := NewHTTPResponse()
|
||||
|
||||
// Get response table
|
||||
state.GetGlobal("__http_responses")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
ReleaseResponse(response)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check for response at thread index
|
||||
state.PushNumber(1)
|
||||
state.GetTable(-2)
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(2)
|
||||
ReleaseResponse(response)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Get status
|
||||
state.GetField(-1, "status")
|
||||
if state.IsNumber(-1) {
|
||||
response.Status = int(state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get headers
|
||||
state.GetField(-1, "headers")
|
||||
if state.IsTable(-1) {
|
||||
// Iterate through headers table
|
||||
state.PushNil() // Start iteration
|
||||
for state.Next(-2) {
|
||||
// Stack has key at -2 and value at -1
|
||||
if state.IsString(-2) && state.IsString(-1) {
|
||||
key := state.ToString(-2)
|
||||
value := state.ToString(-1)
|
||||
response.Headers[key] = value
|
||||
}
|
||||
state.Pop(1) // Pop value, leave key for next iteration
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get cookies
|
||||
state.GetField(-1, "cookies")
|
||||
if state.IsTable(-1) {
|
||||
// Iterate through cookies array
|
||||
length := state.GetTableLength(-1)
|
||||
for i := 1; i <= length; i++ {
|
||||
state.PushNumber(float64(i))
|
||||
state.GetTable(-2)
|
||||
|
||||
if state.IsTable(-1) {
|
||||
cookie := extractCookie(state)
|
||||
if cookie != nil {
|
||||
response.Cookies = append(response.Cookies, cookie)
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Clean up
|
||||
state.Pop(2) // Pop response table and __http_responses
|
||||
|
||||
return response, true
|
||||
}
|
||||
|
||||
// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
|
||||
func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
|
||||
// ApplyResponse applies a Response to a fasthttp.RequestCtx
|
||||
func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
|
||||
// Set status code
|
||||
ctx.SetStatusCode(httpResp.Status)
|
||||
ctx.SetStatusCode(resp.Status)
|
||||
|
||||
// Set headers
|
||||
for name, value := range httpResp.Headers {
|
||||
for name, value := range resp.Headers {
|
||||
ctx.Response.Header.Set(name, value)
|
||||
}
|
||||
|
||||
// Set cookies
|
||||
for _, cookie := range httpResp.Cookies {
|
||||
for _, cookie := range resp.Cookies {
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// Process the body based on its type
|
||||
if httpResp.Body == nil {
|
||||
if resp.Body == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get a buffer from the pool
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
// Set body based on type
|
||||
switch body := httpResp.Body.(type) {
|
||||
switch body := resp.Body.(type) {
|
||||
case string:
|
||||
ctx.SetBodyString(body)
|
||||
case []byte:
|
||||
ctx.SetBody(body)
|
||||
case map[string]any, []any, []float64, []string, []int:
|
||||
// Marshal JSON using a buffer from the pool
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
// Marshal JSON
|
||||
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
||||
// Set content type if not already set
|
||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||
|
@ -243,75 +92,6 @@ func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
|
|||
}
|
||||
}
|
||||
|
||||
// extractCookie grabs cookies from the Lua state
|
||||
func extractCookie(state *luajit.State) *fasthttp.Cookie {
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
|
||||
// Get name
|
||||
state.GetField(-1, "name")
|
||||
if !state.IsString(-1) {
|
||||
state.Pop(1)
|
||||
fasthttp.ReleaseCookie(cookie)
|
||||
return nil // Name is required
|
||||
}
|
||||
cookie.SetKey(state.ToString(-1))
|
||||
state.Pop(1)
|
||||
|
||||
// Get value
|
||||
state.GetField(-1, "value")
|
||||
if state.IsString(-1) {
|
||||
cookie.SetValue(state.ToString(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get path
|
||||
state.GetField(-1, "path")
|
||||
if state.IsString(-1) {
|
||||
cookie.SetPath(state.ToString(-1))
|
||||
} else {
|
||||
cookie.SetPath("/") // Default path
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get domain
|
||||
state.GetField(-1, "domain")
|
||||
if state.IsString(-1) {
|
||||
cookie.SetDomain(state.ToString(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get expires
|
||||
state.GetField(-1, "expires")
|
||||
if state.IsNumber(-1) {
|
||||
expiry := int64(state.ToNumber(-1))
|
||||
cookie.SetExpire(time.Unix(expiry, 0))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get max age
|
||||
state.GetField(-1, "max_age")
|
||||
if state.IsNumber(-1) {
|
||||
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get secure
|
||||
state.GetField(-1, "secure")
|
||||
if state.IsBoolean(-1) {
|
||||
cookie.SetSecure(state.ToBoolean(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get http only
|
||||
state.GetField(-1, "http_only")
|
||||
if state.IsBoolean(-1) {
|
||||
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
return cookie
|
||||
}
|
||||
|
||||
// httpRequest makes an HTTP request and returns the result to Lua
|
||||
func httpRequest(state *luajit.State) int {
|
||||
// Get method (required)
|
||||
|
@ -328,7 +108,7 @@ func httpRequest(state *luajit.State) int {
|
|||
}
|
||||
urlStr := state.ToString(2)
|
||||
|
||||
// Parse URL to check if it's valid and if it's allowed
|
||||
// Parse URL to check if it's valid
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
state.PushString("Invalid URL: " + err.Error())
|
||||
|
@ -336,38 +116,7 @@ func httpRequest(state *luajit.State) int {
|
|||
}
|
||||
|
||||
// Get client configuration
|
||||
var config HTTPClientConfig = DefaultHTTPClientConfig
|
||||
state.GetGlobal("__http_client_config")
|
||||
if !state.IsNil(-1) && state.IsTable(-1) {
|
||||
// Extract max timeout
|
||||
state.GetField(-1, "max_timeout")
|
||||
if state.IsNumber(-1) {
|
||||
config.MaxTimeout = time.Duration(state.ToNumber(-1)) * time.Second
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract default timeout
|
||||
state.GetField(-1, "default_timeout")
|
||||
if state.IsNumber(-1) {
|
||||
config.DefaultTimeout = time.Duration(state.ToNumber(-1)) * time.Second
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract max response size
|
||||
state.GetField(-1, "max_response_size")
|
||||
if state.IsNumber(-1) {
|
||||
config.MaxResponseSize = int64(state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract allow remote
|
||||
state.GetField(-1, "allow_remote")
|
||||
if state.IsBoolean(-1) {
|
||||
config.AllowRemote = state.ToBoolean(-1)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
state.Pop(1)
|
||||
config := DefaultHTTPClientConfig
|
||||
|
||||
// Check if remote connections are allowed
|
||||
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
|
||||
|
@ -409,6 +158,7 @@ func httpRequest(state *luajit.State) int {
|
|||
}
|
||||
|
||||
req.SetBody(buf.Bytes())
|
||||
req.Header.SetContentType("application/json")
|
||||
} else {
|
||||
state.PushString("Body must be a string or table")
|
||||
return -1
|
||||
|
@ -417,12 +167,7 @@ func httpRequest(state *luajit.State) int {
|
|||
|
||||
// Process options (headers, timeout, etc.)
|
||||
timeout := config.DefaultTimeout
|
||||
if state.GetTop() >= 4 && !state.IsNil(4) {
|
||||
if !state.IsTable(4) {
|
||||
state.PushString("Options must be a table")
|
||||
return -1
|
||||
}
|
||||
|
||||
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
|
||||
// Process headers
|
||||
state.GetField(4, "headers")
|
||||
if state.IsTable(-1) {
|
||||
|
@ -454,23 +199,6 @@ func httpRequest(state *luajit.State) int {
|
|||
}
|
||||
state.Pop(1) // Pop timeout
|
||||
|
||||
// Set content type for POST/PUT if body is present and content-type not manually set
|
||||
if (method == "POST" || method == "PUT") && req.Body() != nil && req.Header.Peek("Content-Type") == nil {
|
||||
// Check if options specify content type
|
||||
state.GetField(4, "content_type")
|
||||
if state.IsString(-1) {
|
||||
req.Header.Set("Content-Type", state.ToString(-1))
|
||||
} else {
|
||||
// Default to JSON if body is a table, otherwise plain text
|
||||
if state.IsTable(3) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
}
|
||||
}
|
||||
state.Pop(1) // Pop content_type
|
||||
}
|
||||
|
||||
// Process query parameters
|
||||
state.GetField(4, "query")
|
||||
if state.IsTable(-1) {
|
||||
|
@ -480,7 +208,6 @@ func httpRequest(state *luajit.State) int {
|
|||
// Iterate through query params
|
||||
state.PushNil() // Start iteration
|
||||
for state.Next(-2) {
|
||||
// Stack now has key at -2 and value at -1
|
||||
if state.IsString(-2) {
|
||||
paramName := state.ToString(-2)
|
||||
|
||||
|
@ -571,3 +298,37 @@ func httpRequest(state *luajit.State) int {
|
|||
|
||||
return 1
|
||||
}
|
||||
|
||||
// generateToken creates a cryptographically secure random token
|
||||
func generateToken(state *luajit.State) int {
|
||||
// Get the length from the Lua arguments (default to 32)
|
||||
length := 32
|
||||
if state.GetTop() >= 1 && state.IsNumber(1) {
|
||||
length = int(state.ToNumber(1))
|
||||
}
|
||||
|
||||
// Enforce minimum length for security
|
||||
if length < 16 {
|
||||
length = 16
|
||||
}
|
||||
|
||||
// Generate secure random bytes
|
||||
tokenBytes := make([]byte, length)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
logger.Error("Failed to generate secure token: %v", err)
|
||||
state.PushString("")
|
||||
return 1 // Return empty string on error
|
||||
}
|
||||
|
||||
// Encode as base64
|
||||
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
// Trim to requested length (base64 might be longer)
|
||||
if len(token) > length {
|
||||
token = token[:length]
|
||||
}
|
||||
|
||||
// Push the token to the Lua stack
|
||||
state.PushString(token)
|
||||
return 1 // One return value
|
||||
}
|
|
@ -6,6 +6,8 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
|
@ -15,61 +17,15 @@ type ModuleConfig struct {
|
|||
LibDirs []string // Additional library directories
|
||||
}
|
||||
|
||||
// ModuleInfo stores information about a loaded module
|
||||
type ModuleInfo struct {
|
||||
Name string
|
||||
Path string
|
||||
IsCore bool
|
||||
Bytecode []byte
|
||||
}
|
||||
|
||||
// ModuleLoader manages module loading and caching
|
||||
type ModuleLoader struct {
|
||||
config *ModuleConfig
|
||||
registry *ModuleRegistry
|
||||
pathCache map[string]string // Cache module paths for fast lookups
|
||||
bytecodeCache map[string][]byte // Cache of compiled bytecode
|
||||
debug bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// ModuleRegistry keeps track of Lua modules for file watching
|
||||
type ModuleRegistry struct {
|
||||
// Maps file paths to module names
|
||||
pathToModule sync.Map
|
||||
// Maps module names to file paths
|
||||
moduleToPath sync.Map
|
||||
}
|
||||
|
||||
// NewModuleRegistry creates a new module registry
|
||||
func NewModuleRegistry() *ModuleRegistry {
|
||||
return &ModuleRegistry{}
|
||||
}
|
||||
|
||||
// Register adds a module path to the registry
|
||||
func (r *ModuleRegistry) Register(path string, name string) {
|
||||
r.pathToModule.Store(path, name)
|
||||
r.moduleToPath.Store(name, path)
|
||||
}
|
||||
|
||||
// GetModuleName retrieves a module name by path
|
||||
func (r *ModuleRegistry) GetModuleName(path string) (string, bool) {
|
||||
value, ok := r.pathToModule.Load(path)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return value.(string), true
|
||||
}
|
||||
|
||||
// GetModulePath retrieves a path by module name
|
||||
func (r *ModuleRegistry) GetModulePath(name string) (string, bool) {
|
||||
value, ok := r.moduleToPath.Load(name)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return value.(string), true
|
||||
}
|
||||
|
||||
// NewModuleLoader creates a new module loader
|
||||
func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
|
||||
if config == nil {
|
||||
|
@ -81,7 +37,6 @@ func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
|
|||
|
||||
return &ModuleLoader{
|
||||
config: config,
|
||||
registry: NewModuleRegistry(),
|
||||
pathCache: make(map[string]string),
|
||||
bytecodeCache: make(map[string][]byte),
|
||||
debug: false,
|
||||
|
@ -100,6 +55,13 @@ func (l *ModuleLoader) SetScriptDir(dir string) {
|
|||
l.config.ScriptDir = dir
|
||||
}
|
||||
|
||||
// debugLog logs a message if debug mode is enabled
|
||||
func (l *ModuleLoader) debugLog(format string, args ...interface{}) {
|
||||
if l.debug {
|
||||
logger.Debug("ModuleLoader "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupRequire configures the require system in a Lua state
|
||||
func (l *ModuleLoader) SetupRequire(state *luajit.State) error {
|
||||
l.mu.RLock()
|
||||
|
@ -207,6 +169,8 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
|
|||
continue
|
||||
}
|
||||
|
||||
l.debugLog("Scanning directory: %s", absDir)
|
||||
|
||||
// Find all Lua files
|
||||
err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") {
|
||||
|
@ -223,19 +187,22 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
|
|||
modName := strings.TrimSuffix(relPath, ".lua")
|
||||
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
||||
|
||||
l.debugLog("Found module: %s at %s", modName, path)
|
||||
|
||||
// Register in our caches
|
||||
l.pathCache[modName] = path
|
||||
l.registry.Register(path, modName)
|
||||
|
||||
// Load file content
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
l.debugLog("Failed to read module file: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compile to bytecode
|
||||
bytecode, err := state.CompileBytecode(string(content), path)
|
||||
if err != nil {
|
||||
l.debugLog("Failed to compile module: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -354,11 +321,12 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
|
|||
// Clean path for proper comparison
|
||||
path = filepath.Clean(path)
|
||||
|
||||
// Try direct lookup from registry
|
||||
modName, found := l.registry.GetModuleName(path)
|
||||
if found {
|
||||
// Try direct lookup from cache
|
||||
for modName, modPath := range l.pathCache {
|
||||
if modPath == path {
|
||||
return modName, true
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find by relative path from lib dirs
|
||||
for _, dir := range l.config.LibDirs {
|
||||
|
@ -373,7 +341,7 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
|
|||
}
|
||||
|
||||
if strings.HasSuffix(relPath, ".lua") {
|
||||
modName = strings.TrimSuffix(relPath, ".lua")
|
||||
modName := strings.TrimSuffix(relPath, ".lua")
|
||||
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
||||
return modName, true
|
||||
}
|
||||
|
@ -382,103 +350,6 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
|
|||
return "", false
|
||||
}
|
||||
|
||||
// ReloadModule reloads a module from disk
|
||||
func (l *ModuleLoader) ReloadModule(state *luajit.State, name string) (bool, error) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// Get module path
|
||||
path, ok := l.registry.GetModulePath(name)
|
||||
if !ok {
|
||||
for modName, modPath := range l.pathCache {
|
||||
if modName == name {
|
||||
path = modPath
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !ok || path == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Invalidate module in Lua
|
||||
err := state.DoString(`
|
||||
package.loaded["` + name + `"] = nil
|
||||
__ready_modules["` + name + `"] = nil
|
||||
if package.preload then
|
||||
package.preload["` + name + `"] = nil
|
||||
end
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Check if file still exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
// File was deleted, just invalidate
|
||||
delete(l.pathCache, name)
|
||||
delete(l.bytecodeCache, name)
|
||||
l.registry.moduleToPath.Delete(name)
|
||||
l.registry.pathToModule.Delete(path)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Read updated file
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Compile to bytecode
|
||||
bytecode, err := state.CompileBytecode(string(content), path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Update cache
|
||||
l.bytecodeCache[name] = bytecode
|
||||
|
||||
// Load bytecode into state
|
||||
if err := state.LoadBytecode(bytecode, path); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Update preload
|
||||
luaCode := `
|
||||
local modname = "` + name + `"
|
||||
package.loaded[modname] = nil
|
||||
package.preload[modname] = ...
|
||||
__ready_modules[modname] = true
|
||||
`
|
||||
|
||||
if err := state.DoString(luaCode); err != nil {
|
||||
state.Pop(1) // Remove chunk from stack
|
||||
return false, err
|
||||
}
|
||||
|
||||
state.Pop(1) // Remove chunk from stack
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ResetModules clears non-core modules from package.loaded
|
||||
func (l *ModuleLoader) ResetModules(state *luajit.State) error {
|
||||
return state.DoString(`
|
||||
local core_modules = {
|
||||
string = true, table = true, math = true, os = true,
|
||||
package = true, io = true, coroutine = true, debug = true, _G = true
|
||||
}
|
||||
|
||||
for name in pairs(package.loaded) do
|
||||
if not core_modules[name] then
|
||||
package.loaded[name] = nil
|
||||
end
|
||||
end
|
||||
`)
|
||||
}
|
||||
|
||||
// escapeLuaString escapes special characters in a string for Lua
|
||||
func escapeLuaString(s string) string {
|
||||
replacer := strings.NewReplacer(
|
||||
|
|
76
core/runner/Response.go
Normal file
76
core/runner/Response.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Response represents a unified response from script execution
|
||||
type Response struct {
|
||||
// Basic properties
|
||||
Body any // Body content (any type)
|
||||
Metadata map[string]any // Additional metadata
|
||||
|
||||
// HTTP specific properties
|
||||
Status int // HTTP status code
|
||||
Headers map[string]string // HTTP headers
|
||||
Cookies []*fasthttp.Cookie // HTTP cookies
|
||||
|
||||
// Session information
|
||||
SessionID string // Session ID
|
||||
SessionData map[string]any // Session data
|
||||
SessionModified bool // Whether session was modified
|
||||
}
|
||||
|
||||
// Response pool to reduce allocations
|
||||
var responsePool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Response{
|
||||
Status: 200,
|
||||
Headers: make(map[string]string, 8),
|
||||
Metadata: make(map[string]any, 8),
|
||||
Cookies: make([]*fasthttp.Cookie, 0, 4),
|
||||
SessionData: make(map[string]any, 8),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// NewResponse creates a new response object from the pool
|
||||
func NewResponse() *Response {
|
||||
return responsePool.Get().(*Response)
|
||||
}
|
||||
|
||||
// Release returns a response to the pool after cleaning it
|
||||
func ReleaseResponse(resp *Response) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset fields to default values
|
||||
resp.Body = nil
|
||||
resp.Status = 200
|
||||
|
||||
// Clear maps
|
||||
for k := range resp.Headers {
|
||||
delete(resp.Headers, k)
|
||||
}
|
||||
|
||||
for k := range resp.Metadata {
|
||||
delete(resp.Metadata, k)
|
||||
}
|
||||
|
||||
for k := range resp.SessionData {
|
||||
delete(resp.SessionData, k)
|
||||
}
|
||||
|
||||
// Clear cookies
|
||||
resp.Cookies = resp.Cookies[:0]
|
||||
|
||||
// Reset session info
|
||||
resp.SessionID = ""
|
||||
resp.SessionModified = false
|
||||
|
||||
// Return to pool
|
||||
responsePool.Put(resp)
|
||||
}
|
|
@ -9,7 +9,6 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"Moonshark/core/runner/sandbox"
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
|
@ -29,17 +28,11 @@ type RunnerOption func(*Runner)
|
|||
// State wraps a Lua state with its sandbox
|
||||
type State struct {
|
||||
L *luajit.State // The Lua state
|
||||
sandbox *sandbox.Sandbox // Associated sandbox
|
||||
sandbox *Sandbox // Associated sandbox
|
||||
index int // Index for debugging
|
||||
inUse bool // Whether the state is currently in use
|
||||
}
|
||||
|
||||
// InitHook runs before executing a script
|
||||
type InitHook func(*luajit.State, *Context) error
|
||||
|
||||
// FinalizeHook runs after executing a script
|
||||
type FinalizeHook func(*luajit.State, *Context, any) error
|
||||
|
||||
// Runner runs Lua scripts using a pool of Lua states
|
||||
type Runner struct {
|
||||
states []*State // All states managed by this runner
|
||||
|
@ -49,8 +42,6 @@ type Runner struct {
|
|||
isRunning atomic.Bool // Whether the runner is active
|
||||
mu sync.RWMutex // Mutex for thread safety
|
||||
debug bool // Enable debug logging
|
||||
initHooks []InitHook // Hooks run before script execution
|
||||
finalizeHooks []FinalizeHook // Hooks run after script execution
|
||||
scriptDir string // Current script directory
|
||||
}
|
||||
|
||||
|
@ -83,28 +74,12 @@ func WithLibDirs(dirs ...string) RunnerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithInitHook adds a hook to run before script execution
|
||||
func WithInitHook(hook InitHook) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
r.initHooks = append(r.initHooks, hook)
|
||||
}
|
||||
}
|
||||
|
||||
// WithFinalizeHook adds a hook to run after script execution
|
||||
func WithFinalizeHook(hook FinalizeHook) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
r.finalizeHooks = append(r.finalizeHooks, hook)
|
||||
}
|
||||
}
|
||||
|
||||
// NewRunner creates a new Runner with a pool of states
|
||||
func NewRunner(options ...RunnerOption) (*Runner, error) {
|
||||
// Default configuration
|
||||
runner := &Runner{
|
||||
poolSize: runtime.GOMAXPROCS(0),
|
||||
debug: false,
|
||||
initHooks: make([]InitHook, 0, 4),
|
||||
finalizeHooks: make([]FinalizeHook, 0, 4),
|
||||
}
|
||||
|
||||
// Apply options
|
||||
|
@ -121,6 +96,11 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
|
|||
runner.moduleLoader = NewModuleLoader(config)
|
||||
}
|
||||
|
||||
// Enable debug if requested
|
||||
if runner.debug {
|
||||
runner.moduleLoader.EnableDebug()
|
||||
}
|
||||
|
||||
// Initialize states and pool
|
||||
runner.states = make([]*State, runner.poolSize)
|
||||
runner.statePool = make(chan int, runner.poolSize)
|
||||
|
@ -144,7 +124,7 @@ func (r *Runner) debugLog(format string, args ...interface{}) {
|
|||
|
||||
// initializeStates creates and initializes all states in the pool
|
||||
func (r *Runner) initializeStates() error {
|
||||
r.debugLog("is initializing %d states", r.poolSize)
|
||||
r.debugLog("Initializing %d states", r.poolSize)
|
||||
|
||||
// Create all states
|
||||
for i := 0; i < r.poolSize; i++ {
|
||||
|
@ -174,39 +154,36 @@ func (r *Runner) createState(index int) (*State, error) {
|
|||
}
|
||||
|
||||
// Create sandbox
|
||||
sb := sandbox.NewSandbox()
|
||||
if r.debug && verbose {
|
||||
sb := NewSandbox()
|
||||
if r.debug {
|
||||
sb.EnableDebug()
|
||||
}
|
||||
|
||||
// Set up require system
|
||||
// Set up sandbox
|
||||
if err := sb.Setup(L); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Set up module loader
|
||||
if err := r.moduleLoader.SetupRequire(L); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Initialize all core modules from the registry
|
||||
if err := GlobalRegistry.Initialize(L, index); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Set up sandbox after core modules are initialized
|
||||
if err := sb.Setup(L, index); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Preload all modules
|
||||
// Preload modules
|
||||
if err := r.moduleLoader.PreloadModules(L); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, errors.New("failed to preload modules")
|
||||
}
|
||||
|
||||
if verbose {
|
||||
r.debugLog("Lua state %d initialized successfully", index)
|
||||
}
|
||||
|
||||
return &State{
|
||||
L: L,
|
||||
sandbox: sb,
|
||||
|
@ -215,8 +192,8 @@ func (r *Runner) createState(index int) (*State, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Execute runs a script with context
|
||||
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
||||
// Execute runs a script in a sandbox with context
|
||||
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
|
||||
if !r.isRunning.Load() {
|
||||
return nil, ErrRunnerClosed
|
||||
}
|
||||
|
@ -263,70 +240,17 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context,
|
|||
}
|
||||
}()
|
||||
|
||||
// Run init hooks
|
||||
for _, hook := range r.initHooks {
|
||||
if err := hook(state.L, execCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get context values
|
||||
var ctxValues map[string]any
|
||||
if execCtx != nil {
|
||||
ctxValues = execCtx.Values
|
||||
}
|
||||
|
||||
// Execute in sandbox with optimized context handling
|
||||
var result any
|
||||
var err error
|
||||
|
||||
if execCtx != nil && execCtx.RequestCtx != nil {
|
||||
// Use OptimizedExecute directly with the full context if we have RequestCtx
|
||||
result, err = state.sandbox.OptimizedExecute(state.L, bytecode, &sandbox.Context{
|
||||
Values: ctxValues,
|
||||
RequestCtx: execCtx.RequestCtx,
|
||||
})
|
||||
} else {
|
||||
// Otherwise use standard Execute with just values
|
||||
result, err = state.sandbox.Execute(state.L, bytecode, ctxValues)
|
||||
}
|
||||
|
||||
// Execute in sandbox
|
||||
response, err := state.sandbox.Execute(state.L, bytecode, execCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Run finalize hooks
|
||||
for _, hook := range r.finalizeHooks {
|
||||
if hookErr := hook(state.L, execCtx, result); hookErr != nil {
|
||||
return nil, hookErr
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Check for HTTP response if we don't have a RequestCtx or if we still have a result
|
||||
if execCtx == nil || execCtx.RequestCtx == nil || result != nil {
|
||||
httpResp, hasResponse := sandbox.GetHTTPResponse(state.L)
|
||||
if hasResponse {
|
||||
// Set result as body if not already set
|
||||
if httpResp.Body == nil {
|
||||
httpResp.Body = result
|
||||
}
|
||||
|
||||
// Apply directly to request context if available
|
||||
if execCtx != nil && execCtx.RequestCtx != nil {
|
||||
sandbox.ApplyHTTPResponse(httpResp, execCtx.RequestCtx)
|
||||
sandbox.ReleaseResponse(httpResp)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return httpResp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Run executes a Lua script (convenience wrapper)
|
||||
func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
||||
// Run executes a Lua script with immediate context
|
||||
func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
|
||||
return r.Execute(context.Background(), bytecode, execCtx, scriptPath)
|
||||
}
|
||||
|
||||
|
@ -362,6 +286,7 @@ cleanup:
|
|||
}
|
||||
}
|
||||
|
||||
r.debugLog("Runner closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -374,6 +299,8 @@ func (r *Runner) RefreshStates() error {
|
|||
return ErrRunnerClosed
|
||||
}
|
||||
|
||||
r.debugLog("Refreshing all states...")
|
||||
|
||||
// Drain all states from the pool
|
||||
for {
|
||||
select {
|
||||
|
@ -407,81 +334,6 @@ cleanup:
|
|||
return nil
|
||||
}
|
||||
|
||||
// AddInitHook adds a hook to be called before script execution
|
||||
func (r *Runner) AddInitHook(hook InitHook) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.initHooks = append(r.initHooks, hook)
|
||||
}
|
||||
|
||||
// AddFinalizeHook adds a hook to be called after script execution
|
||||
func (r *Runner) AddFinalizeHook(hook FinalizeHook) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.finalizeHooks = append(r.finalizeHooks, hook)
|
||||
}
|
||||
|
||||
// GetStateCount returns the number of initialized states
|
||||
func (r *Runner) GetStateCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, state := range r.states {
|
||||
if state != nil {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// GetActiveStateCount returns the number of states currently in use
|
||||
func (r *Runner) GetActiveStateCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, state := range r.states {
|
||||
if state != nil && state.inUse {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// GetModuleCount returns the number of loaded modules in the first available state
|
||||
func (r *Runner) GetModuleCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if !r.isRunning.Load() {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Find first available state
|
||||
for _, state := range r.states {
|
||||
if state != nil && !state.inUse {
|
||||
// Execute a Lua snippet to count modules
|
||||
if res, err := state.L.ExecuteWithResult(`
|
||||
local count = 0
|
||||
for _ in pairs(package.loaded) do
|
||||
count = count + 1
|
||||
end
|
||||
return count
|
||||
`); err == nil {
|
||||
if num, ok := res.(float64); ok {
|
||||
return int(num)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// NotifyFileChanged alerts the runner about file changes
|
||||
func (r *Runner) NotifyFileChanged(filePath string) bool {
|
||||
r.debugLog("File change detected: %s", filePath)
|
||||
|
@ -513,9 +365,6 @@ func (r *Runner) RefreshModule(moduleName string) bool {
|
|||
|
||||
r.debugLog("Refreshing module: %s", moduleName)
|
||||
|
||||
// Check if it's a core module
|
||||
coreName, isCore := GlobalRegistry.MatchModuleName(moduleName)
|
||||
|
||||
success := true
|
||||
for _, state := range r.states {
|
||||
if state == nil || state.inUse {
|
||||
|
@ -525,16 +374,39 @@ func (r *Runner) RefreshModule(moduleName string) bool {
|
|||
// Invalidate module in Lua
|
||||
if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil {
|
||||
success = false
|
||||
continue
|
||||
}
|
||||
|
||||
// For core modules, reinitialize them
|
||||
if isCore {
|
||||
if err := GlobalRegistry.InitializeModule(state.L, coreName); err != nil {
|
||||
success = false
|
||||
}
|
||||
r.debugLog("Failed to invalidate module %s: %v", moduleName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return success
|
||||
}
|
||||
|
||||
// GetStateCount returns the number of initialized states
|
||||
func (r *Runner) GetStateCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, state := range r.states {
|
||||
if state != nil {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// GetActiveStateCount returns the number of states currently in use
|
||||
func (r *Runner) GetActiveStateCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
for _, state := range r.states {
|
||||
if state != nil && state.inUse {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
|
345
core/runner/Sandbox.go
Normal file
345
core/runner/Sandbox.go
Normal file
|
@ -0,0 +1,345 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Error represents a simple error string
|
||||
type Error string
|
||||
|
||||
func (e Error) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// Error types
|
||||
var (
|
||||
ErrSandboxNotInitialized = Error("sandbox not initialized")
|
||||
)
|
||||
|
||||
// Sandbox provides a secure execution environment for Lua scripts
|
||||
type Sandbox struct {
|
||||
modules map[string]any
|
||||
debug bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSandbox creates a new sandbox environment
|
||||
func NewSandbox() *Sandbox {
|
||||
return &Sandbox{
|
||||
modules: make(map[string]any, 8),
|
||||
debug: false,
|
||||
}
|
||||
}
|
||||
|
||||
// EnableDebug turns on debug logging
|
||||
func (s *Sandbox) EnableDebug() {
|
||||
s.debug = true
|
||||
}
|
||||
|
||||
// debugLog logs a message if debug mode is enabled
|
||||
func (s *Sandbox) debugLog(format string, args ...interface{}) {
|
||||
if s.debug {
|
||||
logger.Debug("Sandbox "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
func (s *Sandbox) AddModule(name string, module any) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.modules[name] = module
|
||||
s.debugLog("Added module: %s", name)
|
||||
}
|
||||
|
||||
// Setup initializes the sandbox in a Lua state
|
||||
func (s *Sandbox) Setup(state *luajit.State) error {
|
||||
s.debugLog("Setting up sandbox...")
|
||||
|
||||
// Load the sandbox code
|
||||
if err := loadSandboxIntoState(state); err != nil {
|
||||
s.debugLog("Failed to load sandbox: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Register core functions
|
||||
if err := s.registerCoreFunctions(state); err != nil {
|
||||
s.debugLog("Failed to register core functions: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Register custom modules in the global environment
|
||||
s.mu.RLock()
|
||||
for name, module := range s.modules {
|
||||
s.debugLog("Registering module: %s", name)
|
||||
if err := state.PushValue(module); err != nil {
|
||||
s.mu.RUnlock()
|
||||
s.debugLog("Failed to register module %s: %v", name, err)
|
||||
return err
|
||||
}
|
||||
state.SetGlobal(name)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
s.debugLog("Sandbox setup complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerCoreFunctions registers all built-in functions in the Lua state
|
||||
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
|
||||
// Register HTTP functions
|
||||
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register utility functions
|
||||
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Additional registrations can be added here
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute runs a Lua script in the sandbox with the given context
|
||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
|
||||
s.debugLog("Executing script...")
|
||||
|
||||
// Create a response object
|
||||
response := NewResponse()
|
||||
|
||||
// Get a buffer for string operations
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
// Load bytecode
|
||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||
ReleaseResponse(response)
|
||||
s.debugLog("Failed to load bytecode: %v", err)
|
||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||
}
|
||||
|
||||
// Initialize session data in Lua
|
||||
if ctx.SessionID != "" {
|
||||
// Set session ID
|
||||
state.PushString(ctx.SessionID)
|
||||
state.SetGlobal("__session_id")
|
||||
|
||||
// Set session data
|
||||
if err := state.PushTable(ctx.SessionData); err != nil {
|
||||
ReleaseResponse(response)
|
||||
s.debugLog("Failed to push session data: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
state.SetGlobal("__session_data")
|
||||
|
||||
// Reset modification flag
|
||||
state.PushBoolean(false)
|
||||
state.SetGlobal("__session_modified")
|
||||
} else {
|
||||
// Initialize empty session
|
||||
if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil {
|
||||
s.debugLog("Failed to initialize empty session data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set up context values for execution
|
||||
if err := state.PushTable(ctx.Values); err != nil {
|
||||
ReleaseResponse(response)
|
||||
s.debugLog("Failed to push context values: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the execution function
|
||||
state.GetGlobal("__execute_script")
|
||||
if !state.IsFunction(-1) {
|
||||
state.Pop(1) // Pop non-function
|
||||
ReleaseResponse(response)
|
||||
s.debugLog("__execute_script is not a function")
|
||||
return nil, ErrSandboxNotInitialized
|
||||
}
|
||||
|
||||
// Push function and context to stack
|
||||
state.PushCopy(-2) // bytecode
|
||||
state.PushCopy(-2) // context
|
||||
|
||||
// Remove duplicates
|
||||
state.Remove(-4)
|
||||
state.Remove(-3)
|
||||
|
||||
// Execute with 2 args, 1 result
|
||||
if err := state.Call(2, 1); err != nil {
|
||||
ReleaseResponse(response)
|
||||
s.debugLog("Execution failed: %v", err)
|
||||
return nil, fmt.Errorf("script execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Set response body from result
|
||||
body, err := state.ToValue(-1)
|
||||
if err == nil {
|
||||
response.Body = body
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract HTTP response data from Lua state
|
||||
s.extractResponseData(state, response)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// extractResponseData pulls response info from the Lua state
|
||||
func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) {
|
||||
// Get HTTP response
|
||||
state.GetGlobal("__http_responses")
|
||||
if !state.IsNil(-1) && state.IsTable(-1) {
|
||||
state.PushNumber(1)
|
||||
state.GetTable(-2)
|
||||
|
||||
if !state.IsNil(-1) && state.IsTable(-1) {
|
||||
// Extract status
|
||||
state.GetField(-1, "status")
|
||||
if state.IsNumber(-1) {
|
||||
response.Status = int(state.ToNumber(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract headers
|
||||
state.GetField(-1, "headers")
|
||||
if state.IsTable(-1) {
|
||||
state.PushNil() // Start iteration
|
||||
for state.Next(-2) {
|
||||
if state.IsString(-2) && state.IsString(-1) {
|
||||
key := state.ToString(-2)
|
||||
value := state.ToString(-1)
|
||||
response.Headers[key] = value
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract cookies
|
||||
state.GetField(-1, "cookies")
|
||||
if state.IsTable(-1) {
|
||||
length := state.GetTableLength(-1)
|
||||
for i := 1; i <= length; i++ {
|
||||
state.PushNumber(float64(i))
|
||||
state.GetTable(-2)
|
||||
|
||||
if state.IsTable(-1) {
|
||||
s.extractCookie(state, response)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract metadata if present
|
||||
state.GetField(-1, "metadata")
|
||||
if state.IsTable(-1) {
|
||||
table, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
for k, v := range table {
|
||||
response.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Extract session data
|
||||
state.GetGlobal("__session_modified")
|
||||
if state.IsBoolean(-1) && state.ToBoolean(-1) {
|
||||
response.SessionModified = true
|
||||
|
||||
// Get session ID
|
||||
state.GetGlobal("__session_id")
|
||||
if state.IsString(-1) {
|
||||
response.SessionID = state.ToString(-1)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get session data
|
||||
state.GetGlobal("__session_data")
|
||||
if state.IsTable(-1) {
|
||||
sessionData, err := state.ToTable(-1)
|
||||
if err == nil {
|
||||
for k, v := range sessionData {
|
||||
response.SessionData[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
state.Pop(1)
|
||||
}
|
||||
|
||||
// extractCookie pulls cookie data from the current table on the stack
|
||||
func (s *Sandbox) extractCookie(state *luajit.State, response *Response) {
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
|
||||
// Get name (required)
|
||||
state.GetField(-1, "name")
|
||||
if !state.IsString(-1) {
|
||||
state.Pop(1)
|
||||
fasthttp.ReleaseCookie(cookie)
|
||||
return
|
||||
}
|
||||
cookie.SetKey(state.ToString(-1))
|
||||
state.Pop(1)
|
||||
|
||||
// Get value
|
||||
state.GetField(-1, "value")
|
||||
if state.IsString(-1) {
|
||||
cookie.SetValue(state.ToString(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get path
|
||||
state.GetField(-1, "path")
|
||||
if state.IsString(-1) {
|
||||
cookie.SetPath(state.ToString(-1))
|
||||
} else {
|
||||
cookie.SetPath("/") // Default
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get domain
|
||||
state.GetField(-1, "domain")
|
||||
if state.IsString(-1) {
|
||||
cookie.SetDomain(state.ToString(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
// Get other parameters
|
||||
state.GetField(-1, "http_only")
|
||||
if state.IsBoolean(-1) {
|
||||
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(-1, "secure")
|
||||
if state.IsBoolean(-1) {
|
||||
cookie.SetSecure(state.ToBoolean(-1))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
state.GetField(-1, "max_age")
|
||||
if state.IsNumber(-1) {
|
||||
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
response.Cookies = append(response.Cookies, cookie)
|
||||
}
|
|
@ -1,240 +0,0 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"Moonshark/core/runner/sandbox"
|
||||
"Moonshark/core/sessions"
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// SessionHandler handles session management for Lua scripts
|
||||
type SessionHandler struct {
|
||||
manager *sessions.SessionManager
|
||||
debugLog bool
|
||||
}
|
||||
|
||||
// NewSessionHandler creates a new session handler
|
||||
func NewSessionHandler(manager *sessions.SessionManager) *SessionHandler {
|
||||
return &SessionHandler{
|
||||
manager: manager,
|
||||
debugLog: false,
|
||||
}
|
||||
}
|
||||
|
||||
// EnableDebug enables debug logging
|
||||
func (h *SessionHandler) EnableDebug() {
|
||||
h.debugLog = true
|
||||
}
|
||||
|
||||
// WithSessionManager creates a RunnerOption to add session support
|
||||
func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
handler := NewSessionHandler(manager)
|
||||
r.AddInitHook(handler.preRequestHook)
|
||||
r.AddFinalizeHook(handler.postRequestHook)
|
||||
}
|
||||
}
|
||||
|
||||
// preRequestHook initializes session before script execution
|
||||
func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error {
|
||||
if ctx == nil || ctx.Values["_request_cookies"] == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract cookies from context
|
||||
cookies, ok := ctx.Values["_request_cookies"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the session ID from cookies
|
||||
cookieName := h.manager.CookieOptions()["name"].(string)
|
||||
var sessionID string
|
||||
|
||||
// Check if our session cookie exists
|
||||
if cookieValue, exists := cookies[cookieName]; exists {
|
||||
if strValue, ok := cookieValue.(string); ok && strValue != "" {
|
||||
sessionID = strValue
|
||||
}
|
||||
}
|
||||
|
||||
// Create new session if needed
|
||||
if sessionID == "" {
|
||||
session := h.manager.CreateSession()
|
||||
sessionID = session.ID
|
||||
}
|
||||
|
||||
// Store the session ID in the context
|
||||
ctx.Set("_session_id", sessionID)
|
||||
|
||||
// Get session data
|
||||
session := h.manager.GetSession(sessionID)
|
||||
sessionData := session.GetAll()
|
||||
|
||||
// Set session data in Lua state
|
||||
return SetSessionData(state, sessionID, sessionData)
|
||||
}
|
||||
|
||||
// postRequestHook handles session after script execution
|
||||
func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, result any) error {
|
||||
// Check if session was modified
|
||||
modifiedID, modifiedData, modified := GetSessionData(state)
|
||||
if !modified {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the original session ID from context
|
||||
var sessionID string
|
||||
if ctx != nil {
|
||||
if id, ok := ctx.Values["_session_id"].(string); ok {
|
||||
sessionID = id
|
||||
}
|
||||
}
|
||||
|
||||
// Use the original session ID if the modified one is empty
|
||||
if modifiedID == "" {
|
||||
modifiedID = sessionID
|
||||
}
|
||||
|
||||
if modifiedID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update session in manager
|
||||
session := h.manager.GetSession(modifiedID)
|
||||
session.Clear() // clear to sync deleted values
|
||||
for k, v := range modifiedData {
|
||||
session.Set(k, v)
|
||||
}
|
||||
|
||||
h.manager.SaveSession(session)
|
||||
|
||||
// Add session cookie to result if it's an HTTP response
|
||||
if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
|
||||
h.addSessionCookie(httpResp, modifiedID)
|
||||
} else if ctx != nil && ctx.RequestCtx != nil {
|
||||
// Add cookie directly to the RequestCtx when result is not an HTTP response
|
||||
h.addSessionCookieToRequestCtx(ctx.RequestCtx, modifiedID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addSessionCookie adds a session cookie to an HTTP response
|
||||
func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) {
|
||||
// Get cookie options
|
||||
opts := h.manager.CookieOptions()
|
||||
|
||||
// Check if session cookie is already set
|
||||
cookieName := opts["name"].(string)
|
||||
for _, cookie := range resp.Cookies {
|
||||
if string(cookie.Key()) == cookieName {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Create and add cookie
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
cookie.SetKey(cookieName)
|
||||
cookie.SetValue(sessionID)
|
||||
cookie.SetPath(opts["path"].(string))
|
||||
cookie.SetHTTPOnly(opts["http_only"].(bool))
|
||||
cookie.SetMaxAge(opts["max_age"].(int))
|
||||
|
||||
// Optional cookie parameters
|
||||
if domain, ok := opts["domain"].(string); ok && domain != "" {
|
||||
cookie.SetDomain(domain)
|
||||
}
|
||||
|
||||
if secure, ok := opts["secure"].(bool); ok {
|
||||
cookie.SetSecure(secure)
|
||||
}
|
||||
|
||||
resp.Cookies = append(resp.Cookies, cookie)
|
||||
}
|
||||
|
||||
func (h *SessionHandler) addSessionCookieToRequestCtx(ctx *fasthttp.RequestCtx, sessionID string) {
|
||||
// Get cookie options
|
||||
opts := h.manager.CookieOptions()
|
||||
cookieName := opts["name"].(string)
|
||||
|
||||
// Create cookie
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
|
||||
cookie.SetKey(cookieName)
|
||||
cookie.SetValue(sessionID)
|
||||
cookie.SetPath(opts["path"].(string))
|
||||
cookie.SetHTTPOnly(opts["http_only"].(bool))
|
||||
cookie.SetMaxAge(opts["max_age"].(int))
|
||||
|
||||
// Optional cookie parameters
|
||||
if domain, ok := opts["domain"].(string); ok && domain != "" {
|
||||
cookie.SetDomain(domain)
|
||||
}
|
||||
|
||||
if secure, ok := opts["secure"].(bool); ok {
|
||||
cookie.SetSecure(secure)
|
||||
}
|
||||
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// GetSessionData extracts session data from Lua state
|
||||
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
|
||||
// Check if session was modified
|
||||
state.GetGlobal("__session_modified")
|
||||
modified := state.ToBoolean(-1)
|
||||
state.Pop(1)
|
||||
|
||||
if !modified {
|
||||
return "", nil, false
|
||||
}
|
||||
|
||||
// Get session ID
|
||||
state.GetGlobal("__session_id")
|
||||
sessionID := state.ToString(-1)
|
||||
state.Pop(1)
|
||||
|
||||
// Get session data
|
||||
state.GetGlobal("__session_data")
|
||||
if !state.IsTable(-1) {
|
||||
state.Pop(1)
|
||||
return sessionID, nil, false
|
||||
}
|
||||
|
||||
data, err := state.ToTable(-1)
|
||||
state.Pop(1)
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Failed to extract session data: %v", err)
|
||||
return sessionID, nil, false
|
||||
}
|
||||
|
||||
return sessionID, data, true
|
||||
}
|
||||
|
||||
// SetSessionData sets session data in Lua state
|
||||
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
|
||||
// Set session ID
|
||||
state.PushString(sessionID)
|
||||
state.SetGlobal("__session_id")
|
||||
|
||||
// Set session data
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
|
||||
if err := state.PushTable(data); err != nil {
|
||||
return err
|
||||
}
|
||||
state.SetGlobal("__session_data")
|
||||
|
||||
// Reset modification flag
|
||||
state.PushBoolean(false)
|
||||
state.SetGlobal("__session_modified")
|
||||
|
||||
return nil
|
||||
}
|
|
@ -14,9 +14,6 @@ __ready_modules = {}
|
|||
__session_data = {}
|
||||
__session_id = nil
|
||||
__session_modified = false
|
||||
__env_system = {
|
||||
base_env = {}
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- CORE SANDBOX FUNCTIONALITY
|
||||
|
@ -63,75 +60,6 @@ function __execute_script(fn, ctx)
|
|||
return result
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- MODULE LOADING SYSTEM
|
||||
-- ======================================================================
|
||||
|
||||
-- Setup environment-aware require function
|
||||
function __setup_require(env)
|
||||
-- Create require function specific to this environment
|
||||
env.require = function(modname)
|
||||
-- Check if already loaded
|
||||
if package.loaded[modname] then
|
||||
return package.loaded[modname]
|
||||
end
|
||||
|
||||
-- Check preloaded modules
|
||||
if __ready_modules[modname] then
|
||||
local loader = package.preload[modname]
|
||||
if loader then
|
||||
-- Set environment for loader
|
||||
setfenv(loader, env)
|
||||
|
||||
-- Execute and store result
|
||||
local result = loader()
|
||||
if result == nil then
|
||||
result = true
|
||||
end
|
||||
|
||||
package.loaded[modname] = result
|
||||
return result
|
||||
end
|
||||
end
|
||||
|
||||
-- Direct file load as fallback
|
||||
if __module_paths[modname] then
|
||||
local path = __module_paths[modname]
|
||||
local chunk, err = loadfile(path)
|
||||
if chunk then
|
||||
setfenv(chunk, env)
|
||||
local result = chunk()
|
||||
if result == nil then
|
||||
result = true
|
||||
end
|
||||
package.loaded[modname] = result
|
||||
return result
|
||||
end
|
||||
end
|
||||
|
||||
-- Full path search as last resort
|
||||
local errors = {}
|
||||
for path in package.path:gmatch("[^;]+") do
|
||||
local file_path = path:gsub("?", modname:gsub("%.", "/"))
|
||||
local chunk, err = loadfile(file_path)
|
||||
if chunk then
|
||||
setfenv(chunk, env)
|
||||
local result = chunk()
|
||||
if result == nil then
|
||||
result = true
|
||||
end
|
||||
package.loaded[modname] = result
|
||||
return result
|
||||
end
|
||||
table.insert(errors, "\tno file '" .. file_path .. "'")
|
||||
end
|
||||
|
||||
error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2)
|
||||
end
|
||||
|
||||
return env
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- HTTP MODULE
|
||||
-- ======================================================================
|
||||
|
@ -166,6 +94,18 @@ local http = {
|
|||
http.set_header("Content-Type", content_type)
|
||||
end,
|
||||
|
||||
-- Set metadata (arbitrary data to be returned with response)
|
||||
set_metadata = function(key, value)
|
||||
if type(key) ~= "string" then
|
||||
error("http.set_metadata: key must be a string", 2)
|
||||
end
|
||||
|
||||
local resp = __http_responses[1] or {}
|
||||
resp.metadata = resp.metadata or {}
|
||||
resp.metadata[key] = value
|
||||
__http_responses[1] = resp
|
||||
end,
|
||||
|
||||
-- HTTP client submodule
|
||||
client = {
|
||||
-- Generic request function
|
||||
|
@ -213,10 +153,7 @@ local http = {
|
|||
-- Simple HEAD request
|
||||
head = function(url, options)
|
||||
options = options or {}
|
||||
local old_options = options
|
||||
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
|
||||
local response = http.client.request("HEAD", url, nil, options)
|
||||
return response
|
||||
return http.client.request("HEAD", url, nil, options)
|
||||
end,
|
||||
|
||||
-- Simple OPTIONS request
|
||||
|
@ -271,7 +208,7 @@ local http = {
|
|||
-- Cookie module implementation
|
||||
local cookie = {
|
||||
-- Set a cookie
|
||||
set = function(name, value, options, ...)
|
||||
set = function(name, value, options)
|
||||
if type(name) ~= "string" then
|
||||
error("cookie.set: name must be a string", 2)
|
||||
end
|
||||
|
@ -281,20 +218,8 @@ local cookie = {
|
|||
resp.cookies = resp.cookies or {}
|
||||
__http_responses[1] = resp
|
||||
|
||||
-- Handle options as table or legacy params
|
||||
local opts = {}
|
||||
if type(options) == "table" then
|
||||
opts = options
|
||||
elseif options ~= nil then
|
||||
-- Legacy support: options is actually 'expires'
|
||||
opts.expires = options
|
||||
-- Check for other legacy params (4th-7th args)
|
||||
local args = {...}
|
||||
if args[1] then opts.path = args[1] end
|
||||
if args[2] then opts.domain = args[2] end
|
||||
if args[3] then opts.secure = args[3] end
|
||||
if args[4] ~= nil then opts.http_only = args[4] end
|
||||
end
|
||||
-- Handle options as table
|
||||
local opts = options or {}
|
||||
|
||||
-- Create cookie table
|
||||
local cookie = {
|
||||
|
@ -314,10 +239,8 @@ local cookie = {
|
|||
elseif opts.expires < 0 then
|
||||
cookie.expires = 1
|
||||
cookie.max_age = 0
|
||||
else
|
||||
-- opts.expires == 0: Session cookie
|
||||
-- Do nothing (omitting both expires and max-age creates a session cookie)
|
||||
end
|
||||
-- opts.expires == 0: Session cookie (omitting both expires and max-age)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -342,8 +265,13 @@ local cookie = {
|
|||
local env = getfenv(2)
|
||||
|
||||
-- Check if context exists and has cookies
|
||||
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
|
||||
return tostring(env.ctx.cookies[name])
|
||||
if env.ctx and env.ctx.cookies then
|
||||
return env.ctx.cookies[name]
|
||||
end
|
||||
|
||||
-- If context has request_cookies map
|
||||
if env.ctx and env.ctx._request_cookies then
|
||||
return env.ctx._request_cookies[name]
|
||||
end
|
||||
|
||||
return nil
|
||||
|
@ -372,7 +300,7 @@ local session = {
|
|||
error("session.get: key must be a string", 2)
|
||||
end
|
||||
|
||||
if __session_data and __session_data[key] then
|
||||
if __session_data and __session_data[key] ~= nil then
|
||||
return __session_data[key]
|
||||
end
|
||||
|
||||
|
@ -469,7 +397,7 @@ local csrf = {
|
|||
error("CSRF protection requires the session module", 2)
|
||||
end
|
||||
|
||||
local token = util.generate_token(length)
|
||||
local token = __generate_token(length)
|
||||
session.set(csrf.TOKEN_KEY, token)
|
||||
return token
|
||||
end,
|
||||
|
@ -501,7 +429,9 @@ local csrf = {
|
|||
local env = getfenv(2)
|
||||
|
||||
local form = nil
|
||||
if env.ctx and env.ctx.form then
|
||||
if env.ctx and env.ctx._request_form then
|
||||
form = env.ctx._request_form
|
||||
elseif env.ctx and env.ctx.form then
|
||||
form = env.ctx.form
|
||||
else
|
||||
return false
|
||||
|
@ -518,7 +448,6 @@ local csrf = {
|
|||
end
|
||||
|
||||
-- Constant-time comparison to prevent timing attacks
|
||||
-- This is safe since Lua strings are immutable
|
||||
if #token ~= #session_token then
|
||||
return false
|
||||
end
|
||||
|
@ -535,6 +464,90 @@ local csrf = {
|
|||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- UTIL MODULE
|
||||
-- ======================================================================
|
||||
|
||||
-- Utility module implementation
|
||||
local util = {
|
||||
-- Generate a token (wrapper around __generate_token)
|
||||
generate_token = function(length)
|
||||
return __generate_token(length or 32)
|
||||
end,
|
||||
|
||||
-- Simple JSON stringify (for when you just need a quick string)
|
||||
json_encode = function(value)
|
||||
if type(value) == "table" then
|
||||
local json = "{"
|
||||
local sep = ""
|
||||
for k, v in pairs(value) do
|
||||
json = json .. sep
|
||||
if type(k) == "number" then
|
||||
-- Array-like
|
||||
json = json .. util.json_encode(v)
|
||||
else
|
||||
-- Object-like
|
||||
json = json .. '"' .. k .. '":' .. util.json_encode(v)
|
||||
end
|
||||
sep = ","
|
||||
end
|
||||
return json .. "}"
|
||||
elseif type(value) == "string" then
|
||||
return '"' .. value:gsub('"', '\\"'):gsub('\n', '\\n') .. '"'
|
||||
elseif type(value) == "number" then
|
||||
return tostring(value)
|
||||
elseif type(value) == "boolean" then
|
||||
return value and "true" or "false"
|
||||
elseif value == nil then
|
||||
return "null"
|
||||
end
|
||||
return '"' .. tostring(value) .. '"'
|
||||
end,
|
||||
|
||||
-- Deep copy of tables
|
||||
deep_copy = function(obj)
|
||||
if type(obj) ~= 'table' then return obj end
|
||||
local res = {}
|
||||
for k, v in pairs(obj) do res[k] = util.deep_copy(v) end
|
||||
return res
|
||||
end,
|
||||
|
||||
-- Merge tables
|
||||
merge_tables = function(t1, t2)
|
||||
if type(t1) ~= 'table' or type(t2) ~= 'table' then
|
||||
error("Both arguments must be tables", 2)
|
||||
end
|
||||
|
||||
local result = util.deep_copy(t1)
|
||||
for k, v in pairs(t2) do
|
||||
if type(v) == 'table' and type(result[k]) == 'table' then
|
||||
result[k] = util.merge_tables(result[k], v)
|
||||
else
|
||||
result[k] = v
|
||||
end
|
||||
end
|
||||
return result
|
||||
end,
|
||||
|
||||
-- String utilities
|
||||
string = {
|
||||
-- Trim whitespace
|
||||
trim = function(s)
|
||||
return (s:gsub("^%s*(.-)%s*$", "%1"))
|
||||
end,
|
||||
|
||||
-- Split string
|
||||
split = function(s, delimiter)
|
||||
delimiter = delimiter or ","
|
||||
local result = {}
|
||||
for match in (s..delimiter):gmatch("(.-)"..delimiter) do
|
||||
table.insert(result, match)
|
||||
end
|
||||
return result
|
||||
end
|
||||
}
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- REGISTER MODULES GLOBALLY
|
||||
-- ======================================================================
|
||||
|
@ -544,9 +557,4 @@ _G.http = http
|
|||
_G.cookie = cookie
|
||||
_G.session = session
|
||||
_G.csrf = csrf
|
||||
|
||||
-- Register modules in sandbox base environment
|
||||
__env_system.base_env.http = http
|
||||
__env_system.base_env.cookie = cookie
|
||||
__env_system.base_env.session = session
|
||||
__env_system.base_env.csrf = csrf
|
||||
_G.util = util
|
|
@ -1,98 +0,0 @@
|
|||
package sandbox
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
//go:embed lua/sandbox.lua
|
||||
var sandboxLua string
|
||||
|
||||
// InitializeSandbox loads the embedded Lua sandbox code into a Lua state
|
||||
func InitializeSandbox(state *luajit.State) error {
|
||||
// Compile once, use many times
|
||||
bytecodeOnce.Do(precompileSandbox)
|
||||
|
||||
if sandboxBytecode != nil {
|
||||
logger.Debug("Loading sandbox.lua from precompiled bytecode")
|
||||
return state.LoadAndRunBytecode(sandboxBytecode, "sandbox.lua")
|
||||
}
|
||||
|
||||
// Fallback if compilation failed
|
||||
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
|
||||
return state.DoString(sandboxLua)
|
||||
}
|
||||
|
||||
// ModuleInitializers stores initializer functions for core modules
|
||||
type ModuleInitializers struct {
|
||||
HTTP func(*luajit.State) error
|
||||
Util func(*luajit.State) error
|
||||
Session func(*luajit.State) error
|
||||
Cookie func(*luajit.State) error
|
||||
CSRF func(*luajit.State) error
|
||||
}
|
||||
|
||||
// DefaultInitializers returns the default set of initializers
|
||||
func DefaultInitializers() *ModuleInitializers {
|
||||
return &ModuleInitializers{
|
||||
HTTP: func(state *luajit.State) error {
|
||||
// Register the native Go function first
|
||||
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||
logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Util: func(state *luajit.State) error {
|
||||
// Register util functions
|
||||
return RegisterModule(state, "util", UtilModuleFunctions())
|
||||
},
|
||||
Session: func(state *luajit.State) error {
|
||||
// Session doesn't need special initialization
|
||||
return nil
|
||||
},
|
||||
Cookie: func(state *luajit.State) error {
|
||||
// Cookie doesn't need special initialization
|
||||
return nil
|
||||
},
|
||||
CSRF: func(state *luajit.State) error {
|
||||
// CSRF doesn't need special initialization
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeAll initializes all modules in the Lua state
|
||||
func InitializeAll(state *luajit.State, initializers *ModuleInitializers) error {
|
||||
// Set up dependencies first
|
||||
if err := initializers.Util(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initializers.HTTP(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load the embedded sandbox code
|
||||
if err := InitializeSandbox(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the rest of the modules
|
||||
if err := initializers.Session(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initializers.Cookie(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initializers.CSRF(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
package sandbox
|
||||
|
||||
import (
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// ModuleFunc returns a map of module functions
|
||||
type ModuleFunc func() map[string]luajit.GoFunction
|
||||
|
||||
// StateInitFunc initializes a module in a Lua state
|
||||
type StateInitFunc func(*luajit.State) error
|
||||
|
||||
// RegisterModule registers a map of functions as a Lua module
|
||||
func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
|
||||
// Create a new table for the module
|
||||
state.NewTable()
|
||||
|
||||
// Add each function to the module table
|
||||
for fname, f := range funcs {
|
||||
state.PushString(fname)
|
||||
if err := state.PushGoFunction(f); err != nil {
|
||||
state.Pop(1) // Pop table
|
||||
return err
|
||||
}
|
||||
state.SetTable(-3)
|
||||
}
|
||||
|
||||
// Register the module globally
|
||||
state.SetGlobal(name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ModuleInitFunc creates a state initializer that registers multiple modules
|
||||
func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc {
|
||||
return func(state *luajit.State) error {
|
||||
for name, moduleFunc := range modules {
|
||||
if err := RegisterModule(state, name, moduleFunc()); err != nil {
|
||||
logger.Error("Failed to register module %s: %v", name, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// CombineInitFuncs combines multiple state initializer functions into one
|
||||
func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc {
|
||||
return func(state *luajit.State) error {
|
||||
for _, f := range funcs {
|
||||
if f != nil {
|
||||
if err := f(state); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterLuaCode registers a Lua code snippet in a state
|
||||
func RegisterLuaCode(state *luajit.State, code string) error {
|
||||
return state.DoString(code)
|
||||
}
|
||||
|
||||
// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code
|
||||
func RegisterLuaCodeInitFunc(code string) StateInitFunc {
|
||||
return func(state *luajit.State) error {
|
||||
return RegisterLuaCode(state, code)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module
|
||||
func RegisterLuaModuleInitFunc(name string, code string) StateInitFunc {
|
||||
return func(state *luajit.State) error {
|
||||
// Create name = {} global
|
||||
state.NewTable()
|
||||
state.SetGlobal(name)
|
||||
|
||||
// Then run the module code which will populate it
|
||||
return state.DoString(code)
|
||||
}
|
||||
}
|
|
@ -1,249 +0,0 @@
|
|||
package sandbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Global bytecode cache to improve performance
|
||||
var (
|
||||
sandboxBytecode []byte
|
||||
bytecodeOnce sync.Once
|
||||
)
|
||||
|
||||
// precompileSandbox compiles the sandbox.lua code to bytecode once
|
||||
func precompileSandbox() {
|
||||
tempState := luajit.New()
|
||||
if tempState == nil {
|
||||
logger.Error("Failed to create temporary Lua state for bytecode compilation")
|
||||
return
|
||||
}
|
||||
defer tempState.Close()
|
||||
defer tempState.Cleanup()
|
||||
|
||||
var err error
|
||||
sandboxBytecode, err = tempState.CompileBytecode(sandboxLua, "sandbox.lua")
|
||||
if err != nil {
|
||||
logger.Error("Failed to precompile sandbox.lua: %v", err)
|
||||
} else {
|
||||
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(sandboxBytecode))
|
||||
}
|
||||
}
|
||||
|
||||
// Sandbox provides a secure execution environment for Lua scripts
|
||||
type Sandbox struct {
|
||||
modules map[string]any // Custom modules for environment
|
||||
debug bool // Enable debug output
|
||||
mu sync.RWMutex // Protects modules
|
||||
initializers *ModuleInitializers // Module initializers
|
||||
}
|
||||
|
||||
// NewSandbox creates a new sandbox environment
|
||||
func NewSandbox() *Sandbox {
|
||||
return &Sandbox{
|
||||
modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity
|
||||
debug: false,
|
||||
initializers: DefaultInitializers(),
|
||||
}
|
||||
}
|
||||
|
||||
// EnableDebug turns on debug logging
|
||||
func (s *Sandbox) EnableDebug() {
|
||||
s.debug = true
|
||||
}
|
||||
|
||||
// debugLog logs a message if debug mode is enabled
|
||||
func (s *Sandbox) debugLog(format string, args ...interface{}) {
|
||||
if s.debug {
|
||||
logger.Debug("Sandbox "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// debugLogCont logs a continuation message if debug mode is enabled
|
||||
func (s *Sandbox) debugLogCont(format string, args ...interface{}) {
|
||||
if s.debug {
|
||||
logger.DebugCont(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// AddModule adds a module to the sandbox environment
|
||||
func (s *Sandbox) AddModule(name string, module any) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.modules[name] = module
|
||||
s.debugLog("Added module: %s", name)
|
||||
}
|
||||
|
||||
// Setup initializes the sandbox in a Lua state
|
||||
func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
|
||||
verbose := stateIndex == 0
|
||||
|
||||
if verbose {
|
||||
s.debugLog("Setting up sandbox...")
|
||||
}
|
||||
|
||||
// Initialize modules with the embedded sandbox code
|
||||
if err := InitializeAll(state, s.initializers); err != nil {
|
||||
if verbose {
|
||||
s.debugLog("Failed to initialize sandbox: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Register custom modules in the global environment
|
||||
s.mu.RLock()
|
||||
for name, module := range s.modules {
|
||||
if verbose {
|
||||
s.debugLog("Registering module: %s", name)
|
||||
}
|
||||
if err := state.PushValue(module); err != nil {
|
||||
s.mu.RUnlock()
|
||||
if verbose {
|
||||
s.debugLog("Failed to register module %s: %v", name, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
state.SetGlobal(name)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if verbose {
|
||||
s.debugLogCont("Sandbox setup complete")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute runs bytecode in the sandbox
|
||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
|
||||
// Create a temporary context if we only have a map
|
||||
if ctx != nil {
|
||||
tempCtx := &Context{
|
||||
Values: ctx,
|
||||
}
|
||||
return s.OptimizedExecute(state, bytecode, tempCtx)
|
||||
}
|
||||
|
||||
// Just pass nil through if we have no context
|
||||
return s.OptimizedExecute(state, bytecode, nil)
|
||||
}
|
||||
|
||||
// Context represents execution context for a Lua script
|
||||
type Context struct {
|
||||
// Values stores any context values (route params, HTTP request info, etc.)
|
||||
Values map[string]any
|
||||
// RequestCtx for HTTP requests
|
||||
RequestCtx *fasthttp.RequestCtx
|
||||
}
|
||||
|
||||
// OptimizedExecute runs bytecode with a fasthttp context if available
|
||||
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *Context) (any, error) {
|
||||
// Use a buffer from the pool for any string operations
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
// Load bytecode
|
||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||
s.debugLog("Failed to load bytecode: %v", err)
|
||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||
}
|
||||
|
||||
// Prepare context values
|
||||
var ctxValues map[string]any
|
||||
if ctx != nil {
|
||||
ctxValues = ctx.Values
|
||||
} else {
|
||||
ctxValues = nil
|
||||
}
|
||||
|
||||
// Prepare context table
|
||||
if ctxValues != nil {
|
||||
state.CreateTable(0, len(ctxValues))
|
||||
for k, v := range ctxValues {
|
||||
state.PushString(k)
|
||||
if err := state.PushValue(v); err != nil {
|
||||
state.Pop(2) // Pop key and table
|
||||
s.debugLog("Failed to push context value %s: %v", k, err)
|
||||
return nil, fmt.Errorf("failed to prepare context: %w", err)
|
||||
}
|
||||
state.SetTable(-3)
|
||||
}
|
||||
} else {
|
||||
state.PushNil() // No context
|
||||
}
|
||||
|
||||
// Get execution function
|
||||
state.GetGlobal("__execute_script")
|
||||
if !state.IsFunction(-1) {
|
||||
state.Pop(2) // Pop context and non-function
|
||||
s.debugLog("__execute_script is not a function")
|
||||
return nil, fmt.Errorf("sandbox execution function not found")
|
||||
}
|
||||
|
||||
// Stack setup for call: __execute_script, bytecode function, context
|
||||
state.PushCopy(-3) // bytecode function (copy from -3)
|
||||
state.PushCopy(-3) // context (copy from -3)
|
||||
|
||||
// Clean up duplicate references
|
||||
state.Remove(-5) // Remove original bytecode function
|
||||
state.Remove(-4) // Remove original context
|
||||
|
||||
// Call with 2 args (function, context), 1 result
|
||||
if err := state.Call(2, 1); err != nil {
|
||||
s.debugLog("Execution failed: %v", err)
|
||||
return nil, fmt.Errorf("script execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Get result
|
||||
result, err := state.ToValue(-1)
|
||||
state.Pop(1) // Pop result
|
||||
|
||||
// Check for HTTP response
|
||||
httpResponse, hasResponse := GetHTTPResponse(state)
|
||||
if hasResponse {
|
||||
// Add the script result as the response body
|
||||
httpResponse.Body = result
|
||||
|
||||
// If we have a fasthttp context, apply the response directly
|
||||
if ctx != nil && ctx.RequestCtx != nil {
|
||||
ApplyHTTPResponse(httpResponse, ctx.RequestCtx)
|
||||
ReleaseResponse(httpResponse)
|
||||
return nil, nil // No need to return response object
|
||||
}
|
||||
|
||||
return httpResponse, nil
|
||||
}
|
||||
|
||||
// If we have a fasthttp context and the result needs to be written directly
|
||||
if ctx != nil && ctx.RequestCtx != nil && (result != nil) {
|
||||
// For direct HTTP responses
|
||||
switch r := result.(type) {
|
||||
case string:
|
||||
ctx.RequestCtx.SetBodyString(r)
|
||||
case []byte:
|
||||
ctx.RequestCtx.SetBody(r)
|
||||
case map[string]any, []any:
|
||||
// JSON response
|
||||
ctx.RequestCtx.Response.Header.SetContentType("application/json")
|
||||
if err := json.NewEncoder(buf).Encode(r); err == nil {
|
||||
ctx.RequestCtx.SetBody(buf.Bytes())
|
||||
} else {
|
||||
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
|
||||
}
|
||||
default:
|
||||
// Default string conversion
|
||||
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
package sandbox
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
|
||||
"Moonshark/core/utils/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// UtilModuleInitFunc returns an initializer for the util module
|
||||
func UtilModuleInitFunc() func(*luajit.State) error {
|
||||
return func(state *luajit.State) error {
|
||||
return RegisterModule(state, "util", UtilModuleFunctions())
|
||||
}
|
||||
}
|
||||
|
||||
// UtilModuleFunctions returns all functions for the util module
|
||||
func UtilModuleFunctions() map[string]luajit.GoFunction {
|
||||
return map[string]luajit.GoFunction{
|
||||
"generate_token": GenerateToken,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken creates a cryptographically secure random token
|
||||
func GenerateToken(s *luajit.State) int {
|
||||
// Get the length from the Lua arguments (default to 32)
|
||||
length := 32
|
||||
if s.GetTop() >= 1 && s.IsNumber(1) {
|
||||
length = int(s.ToNumber(1))
|
||||
}
|
||||
|
||||
// Enforce minimum length for security
|
||||
if length < 16 {
|
||||
length = 16
|
||||
}
|
||||
|
||||
// Generate secure random bytes
|
||||
tokenBytes := make([]byte, length)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
s.PushString("")
|
||||
logger.Error("Failed to generate secure token: %v", err)
|
||||
return 1 // Return empty string on error
|
||||
}
|
||||
|
||||
// Encode as base64
|
||||
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
// Trim to requested length (base64 might be longer)
|
||||
if len(token) > length {
|
||||
token = token[:length]
|
||||
}
|
||||
|
||||
// Push the token to the Lua stack
|
||||
s.PushString(token)
|
||||
return 1 // One return value
|
||||
}
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/VictoriaMetrics/fastcache"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -75,7 +76,7 @@ func (sm *SessionManager) GetSession(id string) *Session {
|
|||
|
||||
// Store back with updated timestamp
|
||||
updatedData, _ := json.Marshal(session)
|
||||
sm.cache.Set([]byte(id), updatedData) // Use updatedData, not data
|
||||
sm.cache.Set([]byte(id), updatedData)
|
||||
|
||||
return session
|
||||
}
|
||||
|
@ -141,5 +142,39 @@ func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, ht
|
|||
sm.cookieMaxAge = maxAge
|
||||
}
|
||||
|
||||
// GetSessionFromRequest extracts the session from a request context
|
||||
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
|
||||
cookie := ctx.Request.Header.Cookie(sm.cookieName)
|
||||
if len(cookie) == 0 {
|
||||
// No session cookie, create a new session
|
||||
return sm.CreateSession()
|
||||
}
|
||||
|
||||
// Session cookie exists, get the session
|
||||
return sm.GetSession(string(cookie))
|
||||
}
|
||||
|
||||
// SaveSessionToResponse adds the session cookie to an HTTP response
|
||||
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
|
||||
sm.mu.RLock()
|
||||
cookie.SetKey(sm.cookieName)
|
||||
cookie.SetValue(session.ID)
|
||||
cookie.SetPath(sm.cookiePath)
|
||||
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
|
||||
cookie.SetMaxAge(sm.cookieMaxAge)
|
||||
|
||||
if sm.cookieDomain != "" {
|
||||
cookie.SetDomain(sm.cookieDomain)
|
||||
}
|
||||
|
||||
cookie.SetSecure(sm.cookieSecure)
|
||||
sm.mu.RUnlock()
|
||||
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// GlobalSessionManager is the default session manager instance
|
||||
var GlobalSessionManager = NewSessionManager()
|
||||
|
|
Loading…
Reference in New Issue
Block a user