massive rewrite 1
This commit is contained in:
parent
f4b1e5fad7
commit
5ebcd97662
|
@ -184,8 +184,6 @@ func (s *Moonshark) initRunner() error {
|
||||||
runnerOpts := []runner.RunnerOption{
|
runnerOpts := []runner.RunnerOption{
|
||||||
runner.WithPoolSize(s.Config.Runner.PoolSize),
|
runner.WithPoolSize(s.Config.Runner.PoolSize),
|
||||||
runner.WithLibDirs(s.Config.Dirs.Libs...),
|
runner.WithLibDirs(s.Config.Dirs.Libs...),
|
||||||
runner.WithSessionManager(sessionManager),
|
|
||||||
http.WithCSRFProtection(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add debug option conditionally
|
// Add debug option conditionally
|
||||||
|
|
|
@ -2,18 +2,19 @@ package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"Moonshark/core/runner"
|
"Moonshark/core/runner"
|
||||||
luaCtx "Moonshark/core/runner/context"
|
|
||||||
"Moonshark/core/utils"
|
"Moonshark/core/utils"
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Error for CSRF validation failure
|
||||||
|
var ErrCSRFValidationFailed = errors.New("CSRF token validation failed")
|
||||||
|
|
||||||
// ValidateCSRFToken checks if the CSRF token is valid for a request
|
// ValidateCSRFToken checks if the CSRF token is valid for a request
|
||||||
func ValidateCSRFToken(state *luajit.State, ctx *luaCtx.Context) bool {
|
func ValidateCSRFToken(ctx *runner.Context) bool {
|
||||||
// Only validate for form submissions
|
// Only validate for form submissions
|
||||||
method, ok := ctx.Get("method").(string)
|
method, ok := ctx.Get("method").(string)
|
||||||
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
|
if !ok || (method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE") {
|
||||||
|
@ -34,87 +35,23 @@ func ValidateCSRFToken(state *luajit.State, ctx *luaCtx.Context) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get session token
|
// Get token from session
|
||||||
state.GetGlobal("session")
|
sessionData := ctx.SessionData
|
||||||
if state.IsNil(-1) {
|
if sessionData == nil {
|
||||||
state.Pop(1)
|
logger.Warning("CSRF validation failed: no session data")
|
||||||
logger.Warning("CSRF validation failed: session module not available")
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
state.GetField(-1, "get")
|
sessionToken, ok := sessionData["_csrf_token"].(string)
|
||||||
if !state.IsFunction(-1) {
|
if !ok || sessionToken == "" {
|
||||||
state.Pop(2)
|
|
||||||
logger.Warning("CSRF validation failed: session.get not available")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
state.PushCopy(-1) // Duplicate function
|
|
||||||
state.PushString("_csrf_token")
|
|
||||||
|
|
||||||
if err := state.Call(1, 1); err != nil {
|
|
||||||
state.Pop(3) // Pop error, function and session table
|
|
||||||
logger.Warning("CSRF validation failed: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(3) // Pop nil, function and session table
|
|
||||||
logger.Warning("CSRF validation failed: no token in session")
|
logger.Warning("CSRF validation failed: no token in session")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionToken := state.ToString(-1)
|
|
||||||
state.Pop(3) // Pop token, function and session table
|
|
||||||
|
|
||||||
// Constant-time comparison to prevent timing attacks
|
// Constant-time comparison to prevent timing attacks
|
||||||
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
|
return subtle.ConstantTimeCompare([]byte(formToken), []byte(sessionToken)) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithCSRFProtection creates a runner option to add CSRF protection
|
|
||||||
func WithCSRFProtection() runner.RunnerOption {
|
|
||||||
return func(r *runner.Runner) {
|
|
||||||
r.AddInitHook(func(state *luajit.State, ctx *luaCtx.Context) error {
|
|
||||||
// Get request method
|
|
||||||
method, ok := ctx.Get("method").(string)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only validate for form submissions
|
|
||||||
if method != "POST" && method != "PUT" && method != "PATCH" && method != "DELETE" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for form data
|
|
||||||
form, ok := ctx.Get("form").(map[string]any)
|
|
||||||
if !ok || form == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate CSRF token
|
|
||||||
if !ValidateCSRFToken(state, ctx) {
|
|
||||||
return ErrCSRFValidationFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error for CSRF validation failure
|
|
||||||
var ErrCSRFValidationFailed = &CSRFError{message: "CSRF token validation failed"}
|
|
||||||
|
|
||||||
// CSRFError represents a CSRF validation error
|
|
||||||
type CSRFError struct {
|
|
||||||
message string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error implements the error interface
|
|
||||||
func (e *CSRFError) Error() string {
|
|
||||||
return e.message
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleCSRFError handles a CSRF validation error
|
// HandleCSRFError handles a CSRF validation error
|
||||||
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
||||||
method := string(ctx.Method())
|
method := string(ctx.Method())
|
||||||
|
@ -129,3 +66,39 @@ func HandleCSRFError(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig
|
||||||
errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg)
|
errorHTML := utils.ForbiddenPage(errorConfig, path, errorMsg)
|
||||||
ctx.SetBody([]byte(errorHTML))
|
ctx.SetBody([]byte(errorHTML))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateCSRFToken creates a new CSRF token and stores it in the session
|
||||||
|
func GenerateCSRFToken(ctx *runner.Context, length int) (string, error) {
|
||||||
|
if length < 16 {
|
||||||
|
length = 16 // Minimum token length for security
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create secure random token
|
||||||
|
token, err := GenerateSecureToken(length)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store token in session
|
||||||
|
ctx.SessionData["_csrf_token"] = token
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCSRFToken retrieves the current CSRF token or generates a new one
|
||||||
|
func GetCSRFToken(ctx *runner.Context) (string, error) {
|
||||||
|
// Check if token already exists in session
|
||||||
|
if token, ok := ctx.SessionData["_csrf_token"].(string); ok && token != "" {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new token
|
||||||
|
return GenerateCSRFToken(ctx, 32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRFMiddleware validates CSRF tokens for state-changing requests
|
||||||
|
func CSRFMiddleware(ctx *runner.Context) error {
|
||||||
|
if !ValidateCSRFToken(ctx) {
|
||||||
|
return ErrCSRFValidationFailed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,116 +0,0 @@
|
||||||
package http
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"mime/multipart"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Maximum form parse size (16MB)
|
|
||||||
const maxFormSize = 16 << 20
|
|
||||||
|
|
||||||
// Common errors
|
|
||||||
var (
|
|
||||||
ErrFormSizeTooLarge = errors.New("form size too large")
|
|
||||||
ErrInvalidFormType = errors.New("invalid form content type")
|
|
||||||
)
|
|
||||||
|
|
||||||
// ParseForm parses a POST request body into a map of values
|
|
||||||
// Supports both application/x-www-form-urlencoded and multipart/form-data content types
|
|
||||||
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
|
||||||
// Only handle POST, PUT, PATCH
|
|
||||||
method := string(ctx.Method())
|
|
||||||
if method != "POST" && method != "PUT" && method != "PATCH" {
|
|
||||||
return make(map[string]any), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check content type
|
|
||||||
contentType := string(ctx.Request.Header.ContentType())
|
|
||||||
if contentType == "" {
|
|
||||||
return make(map[string]any), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make(map[string]any)
|
|
||||||
|
|
||||||
// Check for content length to prevent DOS
|
|
||||||
if len(ctx.Request.Body()) > maxFormSize {
|
|
||||||
return nil, ErrFormSizeTooLarge
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle by content type
|
|
||||||
if strings.HasPrefix(contentType, "application/x-www-form-urlencoded") {
|
|
||||||
return parseURLEncodedForm(ctx)
|
|
||||||
} else if strings.HasPrefix(contentType, "multipart/form-data") {
|
|
||||||
return parseMultipartForm(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unrecognized content type
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseURLEncodedForm handles application/x-www-form-urlencoded forms
|
|
||||||
func parseURLEncodedForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
|
||||||
result := make(map[string]any)
|
|
||||||
|
|
||||||
// Process form values directly from PostArgs()
|
|
||||||
ctx.PostArgs().VisitAll(func(key, value []byte) {
|
|
||||||
keyStr := string(key)
|
|
||||||
valStr := string(value)
|
|
||||||
|
|
||||||
// Check if we already have this key
|
|
||||||
if existing, ok := result[keyStr]; ok {
|
|
||||||
// If it's already a slice, append
|
|
||||||
if existingSlice, ok := existing.([]string); ok {
|
|
||||||
result[keyStr] = append(existingSlice, valStr)
|
|
||||||
} else {
|
|
||||||
// Convert to slice and append
|
|
||||||
result[keyStr] = []string{existing.(string), valStr}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// New key
|
|
||||||
result[keyStr] = valStr
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseMultipartForm handles multipart/form-data forms
|
|
||||||
func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
|
||||||
result := make(map[string]any)
|
|
||||||
|
|
||||||
// Parse multipart form
|
|
||||||
form, err := ctx.MultipartForm()
|
|
||||||
if err != nil {
|
|
||||||
if err == multipart.ErrMessageTooLarge || strings.Contains(err.Error(), "too large") {
|
|
||||||
return nil, ErrFormSizeTooLarge
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process form values
|
|
||||||
for key, values := range form.Value {
|
|
||||||
if len(values) == 1 {
|
|
||||||
// Single value
|
|
||||||
result[key] = values[0]
|
|
||||||
} else if len(values) > 1 {
|
|
||||||
// Multiple values - store as string slice
|
|
||||||
strValues := make([]string, len(values))
|
|
||||||
copy(strValues, values)
|
|
||||||
result[key] = strValues
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't handle file uploads here - could be extended in the future
|
|
||||||
// if needed to support file uploads to Lua
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Usage:
|
|
||||||
// After parsing the form with ParseForm, you can add it to the context with:
|
|
||||||
// ctx.Set("form", formData)
|
|
||||||
//
|
|
||||||
// This makes the form data accessible in Lua as ctx.form.field_name
|
|
|
@ -1,43 +0,0 @@
|
||||||
package http
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StatusColors for different status code ranges
|
|
||||||
const (
|
|
||||||
colorGreen = "\033[32m" // 2xx - Success
|
|
||||||
colorCyan = "\033[36m" // 3xx - Redirection
|
|
||||||
colorYellow = "\033[33m" // 4xx - Client Errors
|
|
||||||
colorRed = "\033[31m" // 5xx - Server Errors
|
|
||||||
colorReset = "\033[0m" // Reset color
|
|
||||||
colorGray = "\033[90m"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LogRequest logs an HTTP request with custom formatting
|
|
||||||
func LogRequest(statusCode int, method, path string, duration time.Duration) {
|
|
||||||
statusColor := getStatusColor(statusCode)
|
|
||||||
|
|
||||||
// Use the logger's raw message writer to bypass the standard format
|
|
||||||
logger.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s",
|
|
||||||
colorGray, time.Now().Format(logger.TimeFormat()), colorReset,
|
|
||||||
statusColor, statusCode, method, colorReset, path, colorGray, duration, colorReset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getStatusColor returns the ANSI color code for a status code
|
|
||||||
func getStatusColor(code int) string {
|
|
||||||
switch {
|
|
||||||
case code >= 200 && code < 300:
|
|
||||||
return colorGreen
|
|
||||||
case code >= 300 && code < 400:
|
|
||||||
return colorCyan
|
|
||||||
case code >= 400 && code < 500:
|
|
||||||
return colorYellow
|
|
||||||
case code >= 500:
|
|
||||||
return colorRed
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,43 +0,0 @@
|
||||||
package http
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QueryToLua converts HTTP query parameters to a map that can be used with LuaJIT.
|
|
||||||
// Single value parameters are stored as strings.
|
|
||||||
// Multi-value parameters are converted to []any arrays.
|
|
||||||
func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
|
|
||||||
result := make(map[string]any)
|
|
||||||
|
|
||||||
// Use a map to track keys that have multiple values
|
|
||||||
multiValueKeys := make(map[string]bool)
|
|
||||||
|
|
||||||
// Process all query args
|
|
||||||
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
|
||||||
keyStr := string(key)
|
|
||||||
valStr := string(value)
|
|
||||||
|
|
||||||
if _, exists := result[keyStr]; exists {
|
|
||||||
// This key already exists, convert to array if not already
|
|
||||||
if !multiValueKeys[keyStr] {
|
|
||||||
// First duplicate, convert existing value to array
|
|
||||||
multiValueKeys[keyStr] = true
|
|
||||||
result[keyStr] = []any{result[keyStr], valStr}
|
|
||||||
} else {
|
|
||||||
// Already an array, append
|
|
||||||
result[keyStr] = append(result[keyStr].([]any), valStr)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// New key
|
|
||||||
result[keyStr] = valStr
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// If we don't have any query parameters, return empty map
|
|
||||||
if len(result) == 0 {
|
|
||||||
return make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
|
@ -2,21 +2,17 @@ package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"Moonshark/core/metadata"
|
"Moonshark/core/metadata"
|
||||||
"Moonshark/core/routers"
|
"Moonshark/core/routers"
|
||||||
"Moonshark/core/runner"
|
"Moonshark/core/runner"
|
||||||
luaCtx "Moonshark/core/runner/context"
|
|
||||||
"Moonshark/core/runner/sandbox"
|
|
||||||
"Moonshark/core/sessions"
|
"Moonshark/core/sessions"
|
||||||
"Moonshark/core/utils"
|
"Moonshark/core/utils"
|
||||||
"Moonshark/core/utils/config"
|
"Moonshark/core/utils/config"
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -29,12 +25,14 @@ type Server struct {
|
||||||
loggingEnabled bool
|
loggingEnabled bool
|
||||||
debugMode bool
|
debugMode bool
|
||||||
config *config.Config
|
config *config.Config
|
||||||
|
sessionManager *sessions.SessionManager
|
||||||
errorConfig utils.ErrorPageConfig
|
errorConfig utils.ErrorPageConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new HTTP server with optimized connection settings
|
// New creates a new HTTP server with optimized connection settings
|
||||||
func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runner *runner.Runner,
|
func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter,
|
||||||
loggingEnabled bool, debugMode bool, overrideDir string, config *config.Config) *Server {
|
runner *runner.Runner, loggingEnabled bool, debugMode bool,
|
||||||
|
overrideDir string, config *config.Config) *Server {
|
||||||
|
|
||||||
server := &Server{
|
server := &Server{
|
||||||
luaRouter: luaRouter,
|
luaRouter: luaRouter,
|
||||||
|
@ -43,6 +41,7 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
|
||||||
loggingEnabled: loggingEnabled,
|
loggingEnabled: loggingEnabled,
|
||||||
debugMode: debugMode,
|
debugMode: debugMode,
|
||||||
config: config,
|
config: config,
|
||||||
|
sessionManager: sessions.GlobalSessionManager,
|
||||||
errorConfig: utils.ErrorPageConfig{
|
errorConfig: utils.ErrorPageConfig{
|
||||||
OverrideDir: overrideDir,
|
OverrideDir: overrideDir,
|
||||||
DebugMode: debugMode,
|
DebugMode: debugMode,
|
||||||
|
@ -55,7 +54,7 @@ func New(luaRouter *routers.LuaRouter, staticRouter *routers.StaticRouter, runne
|
||||||
Name: "Moonshark/" + metadata.Version,
|
Name: "Moonshark/" + metadata.Version,
|
||||||
ReadTimeout: 30 * time.Second,
|
ReadTimeout: 30 * time.Second,
|
||||||
WriteTimeout: 30 * time.Second,
|
WriteTimeout: 30 * time.Second,
|
||||||
MaxRequestBodySize: 16 << 20, // 16MB - consistent with Forms.go
|
MaxRequestBodySize: 16 << 20, // 16MB
|
||||||
DisableKeepalive: false,
|
DisableKeepalive: false,
|
||||||
TCPKeepalive: true,
|
TCPKeepalive: true,
|
||||||
TCPKeepalivePeriod: 60 * time.Second,
|
TCPKeepalivePeriod: 60 * time.Second,
|
||||||
|
@ -99,7 +98,7 @@ func (s *Server) handleRequest(ctx *fasthttp.RequestCtx) {
|
||||||
// Process the request
|
// Process the request
|
||||||
s.processRequest(ctx)
|
s.processRequest(ctx)
|
||||||
|
|
||||||
// Log the request with our custom format
|
// Log the request
|
||||||
if s.loggingEnabled {
|
if s.loggingEnabled {
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
LogRequest(ctx.Response.StatusCode(), method, path, duration)
|
LogRequest(ctx.Response.StatusCode(), method, path, duration)
|
||||||
|
@ -153,48 +152,25 @@ func (s *Server) processRequest(ctx *fasthttp.RequestCtx) {
|
||||||
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
|
ctx.SetBody([]byte(utils.NotFoundPage(s.errorConfig, path)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleMethodNotAllowed responds with a 405 Method Not Allowed error
|
|
||||||
func HandleMethodNotAllowed(ctx *fasthttp.RequestCtx, errorConfig utils.ErrorPageConfig) {
|
|
||||||
path := string(ctx.Path())
|
|
||||||
ctx.SetContentType("text/html; charset=utf-8")
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
|
|
||||||
ctx.SetBody([]byte(utils.MethodNotAllowedPage(errorConfig, path)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleLuaRoute executes a Lua route
|
// handleLuaRoute executes a Lua route
|
||||||
// Updated handleLuaRoute function to handle sessions
|
|
||||||
func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params) {
|
func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scriptPath string, params *routers.Params) {
|
||||||
luaCtx := luaCtx.NewHTTPContext(ctx)
|
// Create context for Lua execution
|
||||||
|
luaCtx := runner.NewHTTPContext(ctx)
|
||||||
defer luaCtx.Release()
|
defer luaCtx.Release()
|
||||||
|
|
||||||
method := string(ctx.Method())
|
method := string(ctx.Method())
|
||||||
path := string(ctx.Path())
|
path := string(ctx.Path())
|
||||||
host := string(ctx.Host())
|
host := string(ctx.Host())
|
||||||
|
|
||||||
// Set up context
|
// Set up additional context values
|
||||||
luaCtx.Set("method", method)
|
luaCtx.Set("method", method)
|
||||||
luaCtx.Set("path", path)
|
luaCtx.Set("path", path)
|
||||||
luaCtx.Set("host", host)
|
luaCtx.Set("host", host)
|
||||||
|
|
||||||
// Headers
|
// Initialize session
|
||||||
headerMap := make(map[string]any)
|
session := s.sessionManager.GetSessionFromRequest(ctx)
|
||||||
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
luaCtx.SessionID = session.ID
|
||||||
headerMap[string(key)] = string(value)
|
luaCtx.SessionData = session.GetAll()
|
||||||
})
|
|
||||||
luaCtx.Set("headers", headerMap)
|
|
||||||
|
|
||||||
// Cookies
|
|
||||||
cookieMap := make(map[string]any)
|
|
||||||
ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
|
||||||
cookieMap[string(key)] = string(value)
|
|
||||||
})
|
|
||||||
if len(cookieMap) > 0 {
|
|
||||||
luaCtx.Set("cookies", cookieMap)
|
|
||||||
luaCtx.Set("_request_cookies", cookieMap) // For backward compatibility
|
|
||||||
} else {
|
|
||||||
luaCtx.Set("cookies", make(map[string]any))
|
|
||||||
luaCtx.Set("_request_cookies", make(map[string]any))
|
|
||||||
}
|
|
||||||
|
|
||||||
// URL parameters
|
// URL parameters
|
||||||
if params.Count > 0 {
|
if params.Count > 0 {
|
||||||
|
@ -207,11 +183,7 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||||
luaCtx.Set("params", make(map[string]any))
|
luaCtx.Set("params", make(map[string]any))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query parameters
|
// Parse form data for POST/PUT/PATCH requests
|
||||||
queryMap := QueryToLua(ctx)
|
|
||||||
luaCtx.Set("query", queryMap)
|
|
||||||
|
|
||||||
// Form data
|
|
||||||
if method == "POST" || method == "PUT" || method == "PATCH" {
|
if method == "POST" || method == "PUT" || method == "PATCH" {
|
||||||
formData, err := ParseForm(ctx)
|
formData, err := ParseForm(ctx)
|
||||||
if err == nil && len(formData) > 0 {
|
if err == nil && len(formData) > 0 {
|
||||||
|
@ -226,40 +198,26 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||||
luaCtx.Set("form", make(map[string]any))
|
luaCtx.Set("form", make(map[string]any))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session handling
|
// CSRF middleware for state-changing requests
|
||||||
cookieOpts := sessions.GlobalSessionManager.CookieOptions()
|
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
||||||
cookieName := cookieOpts["name"].(string)
|
if !ValidateCSRFToken(luaCtx) {
|
||||||
sessionCookie := ctx.Request.Header.Cookie(cookieName)
|
HandleCSRFError(ctx, s.errorConfig)
|
||||||
|
return
|
||||||
var sessionID string
|
}
|
||||||
if sessionCookie != nil {
|
|
||||||
sessionID = string(sessionCookie)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get or create session
|
|
||||||
var session *sessions.Session
|
|
||||||
if sessionID != "" {
|
|
||||||
session = sessions.GlobalSessionManager.GetSession(sessionID)
|
|
||||||
} else {
|
|
||||||
session = sessions.GlobalSessionManager.CreateSession()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set session in context
|
|
||||||
luaCtx.Session = session
|
|
||||||
|
|
||||||
// Execute Lua script
|
// Execute Lua script
|
||||||
result, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
|
response, err := s.luaRunner.Run(bytecode, luaCtx, scriptPath)
|
||||||
|
|
||||||
// Special handling for CSRF error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if csrfErr, ok := err.(*CSRFError); ok {
|
logger.Error("Error executing Lua route: %v", err)
|
||||||
logger.Warning("CSRF error executing Lua route: %v", csrfErr)
|
|
||||||
|
// Special handling for specific errors
|
||||||
|
if errors.Is(err, ErrCSRFValidationFailed) {
|
||||||
HandleCSRFError(ctx, s.errorConfig)
|
HandleCSRFError(ctx, s.errorConfig)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normal error handling
|
// General error handling
|
||||||
logger.Error("Error executing Lua route: %v", err)
|
|
||||||
ctx.SetContentType("text/html; charset=utf-8")
|
ctx.SetContentType("text/html; charset=utf-8")
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
errorHTML := utils.InternalErrorPage(s.errorConfig, path, err.Error())
|
errorHTML := utils.InternalErrorPage(s.errorConfig, path, err.Error())
|
||||||
|
@ -267,129 +225,21 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle session updates if needed
|
// Save session if modified
|
||||||
if luaCtx.SessionModified {
|
if response.SessionModified {
|
||||||
sessions.GlobalSessionManager.SaveSession(luaCtx.Session)
|
// Update session data
|
||||||
|
for k, v := range response.SessionData {
|
||||||
// Set session cookie
|
session.Set(k, v)
|
||||||
cookie := fasthttp.AcquireCookie()
|
|
||||||
cookie.SetKey(cookieName)
|
|
||||||
cookie.SetValue(luaCtx.Session.ID)
|
|
||||||
cookie.SetPath(cookieOpts["path"].(string))
|
|
||||||
|
|
||||||
if domain, ok := cookieOpts["domain"].(string); ok && domain != "" {
|
|
||||||
cookie.SetDomain(domain)
|
|
||||||
}
|
}
|
||||||
|
s.sessionManager.SaveSession(session)
|
||||||
if maxAge, ok := cookieOpts["max_age"].(int); ok {
|
s.sessionManager.ApplySessionCookie(ctx, session)
|
||||||
cookie.SetMaxAge(maxAge)
|
|
||||||
}
|
|
||||||
|
|
||||||
cookie.SetSecure(cookieOpts["secure"].(bool))
|
|
||||||
cookie.SetHTTPOnly(cookieOpts["http_only"].(bool))
|
|
||||||
|
|
||||||
ctx.Response.Header.SetCookie(cookie)
|
|
||||||
fasthttp.ReleaseCookie(cookie)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we got a non-nil result, write it to the response
|
// Apply response to HTTP context
|
||||||
if result != nil {
|
runner.ApplyResponse(response, ctx)
|
||||||
writeResponse(ctx, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Content types for responses
|
// Release the response when done
|
||||||
const (
|
runner.ReleaseResponse(response)
|
||||||
contentTypeJSON = "application/json"
|
|
||||||
contentTypePlain = "text/plain"
|
|
||||||
)
|
|
||||||
|
|
||||||
// writeResponse writes the Lua result to the HTTP response
|
|
||||||
func writeResponse(ctx *fasthttp.RequestCtx, result any) {
|
|
||||||
if result == nil {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusNoContent)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// First check the raw type of the result for strong type identification
|
|
||||||
// Sometimes type assertions don't work as expected with interface values
|
|
||||||
resultType := fmt.Sprintf("%T", result)
|
|
||||||
|
|
||||||
// Strong check for HTTP response
|
|
||||||
if strings.Contains(resultType, "HTTPResponse") || strings.Contains(resultType, "sandbox.HTTPResponse") {
|
|
||||||
httpResp, ok := result.(*sandbox.HTTPResponse)
|
|
||||||
if ok {
|
|
||||||
defer sandbox.ReleaseResponse(httpResp)
|
|
||||||
|
|
||||||
// Set response headers
|
|
||||||
for name, value := range httpResp.Headers {
|
|
||||||
ctx.Response.Header.Set(name, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set cookies
|
|
||||||
for _, cookie := range httpResp.Cookies {
|
|
||||||
ctx.Response.Header.SetCookie(cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set status code
|
|
||||||
ctx.SetStatusCode(httpResp.Status)
|
|
||||||
|
|
||||||
// Process the body based on its type
|
|
||||||
if httpResp.Body == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Continue with the body only
|
|
||||||
result = httpResp.Body
|
|
||||||
} else {
|
|
||||||
// We identified it as HTTPResponse but couldn't convert it
|
|
||||||
// This is a programming error
|
|
||||||
logger.Error("Found HTTPResponse type but failed to convert: %v", resultType)
|
|
||||||
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if it's a map (table) or array - return as JSON
|
|
||||||
isJSON := false
|
|
||||||
switch result.(type) {
|
|
||||||
case map[string]any, []any, []float64, []string, []int:
|
|
||||||
isJSON = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if isJSON {
|
|
||||||
setContentTypeIfMissing(ctx, contentTypeJSON)
|
|
||||||
data, err := json.Marshal(result)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to marshal response: %v", err)
|
|
||||||
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx.SetBody(data)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle string and byte slice cases directly
|
|
||||||
switch r := result.(type) {
|
|
||||||
case string:
|
|
||||||
setContentTypeIfMissing(ctx, contentTypePlain)
|
|
||||||
ctx.SetBodyString(r)
|
|
||||||
return
|
|
||||||
case []byte:
|
|
||||||
setContentTypeIfMissing(ctx, contentTypePlain)
|
|
||||||
ctx.SetBody(r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we reach here, it's an unexpected type - convert to string as a last resort
|
|
||||||
setContentTypeIfMissing(ctx, contentTypePlain)
|
|
||||||
ctx.SetBodyString(fmt.Sprintf("%v", result))
|
|
||||||
}
|
|
||||||
|
|
||||||
func setContentTypeIfMissing(ctx *fasthttp.RequestCtx, contentType string) {
|
|
||||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
|
||||||
ctx.SetContentType(contentType)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleDebugStats displays debug statistics
|
// handleDebugStats displays debug statistics
|
||||||
|
@ -399,12 +249,14 @@ func (s *Server) handleDebugStats(ctx *fasthttp.RequestCtx) {
|
||||||
|
|
||||||
// Add component stats
|
// Add component stats
|
||||||
routeCount, bytecodeBytes := s.luaRouter.GetRouteStats()
|
routeCount, bytecodeBytes := s.luaRouter.GetRouteStats()
|
||||||
moduleCount := s.luaRunner.GetModuleCount()
|
//stateCount := s.luaRunner.GetStateCount()
|
||||||
|
//activeStates := s.luaRunner.GetActiveStateCount()
|
||||||
|
|
||||||
stats.Components = utils.ComponentStats{
|
stats.Components = utils.ComponentStats{
|
||||||
RouteCount: routeCount,
|
RouteCount: routeCount,
|
||||||
BytecodeBytes: bytecodeBytes,
|
BytecodeBytes: bytecodeBytes,
|
||||||
ModuleCount: moduleCount,
|
//StatesCount: stateCount,
|
||||||
|
//ActiveStates: activeStates,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate HTML page
|
// Generate HTML page
|
||||||
|
|
206
core/http/Utils.go
Normal file
206
core/http/Utils.go
Normal file
|
@ -0,0 +1,206 @@
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"mime/multipart"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogRequest logs an HTTP request with its status code and duration
|
||||||
|
func LogRequest(statusCode int, method, path string, duration time.Duration) {
|
||||||
|
var statusColor, resetColor, methodColor string
|
||||||
|
|
||||||
|
// Status code colors
|
||||||
|
if statusCode >= 200 && statusCode < 300 {
|
||||||
|
statusColor = "\u001b[32m" // Green for 2xx
|
||||||
|
} else if statusCode >= 300 && statusCode < 400 {
|
||||||
|
statusColor = "\u001b[36m" // Cyan for 3xx
|
||||||
|
} else if statusCode >= 400 && statusCode < 500 {
|
||||||
|
statusColor = "\u001b[33m" // Yellow for 4xx
|
||||||
|
} else {
|
||||||
|
statusColor = "\u001b[31m" // Red for 5xx and others
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method colors
|
||||||
|
switch method {
|
||||||
|
case "GET":
|
||||||
|
methodColor = "\u001b[32m" // Green
|
||||||
|
case "POST":
|
||||||
|
methodColor = "\u001b[34m" // Blue
|
||||||
|
case "PUT":
|
||||||
|
methodColor = "\u001b[33m" // Yellow
|
||||||
|
case "DELETE":
|
||||||
|
methodColor = "\u001b[31m" // Red
|
||||||
|
default:
|
||||||
|
methodColor = "\u001b[35m" // Magenta for others
|
||||||
|
}
|
||||||
|
|
||||||
|
resetColor = "\u001b[0m"
|
||||||
|
|
||||||
|
// Format duration
|
||||||
|
var durationStr string
|
||||||
|
if duration.Milliseconds() < 1 {
|
||||||
|
durationStr = fmt.Sprintf("%.2fµs", float64(duration.Microseconds()))
|
||||||
|
} else if duration.Milliseconds() < 1000 {
|
||||||
|
durationStr = fmt.Sprintf("%.2fms", float64(duration.Microseconds())/1000)
|
||||||
|
} else {
|
||||||
|
durationStr = fmt.Sprintf("%.2fs", duration.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log with colors
|
||||||
|
logger.Server("%s%d%s %s%s%s %s %s",
|
||||||
|
statusColor, statusCode, resetColor,
|
||||||
|
methodColor, method, resetColor,
|
||||||
|
path, durationStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryToLua converts HTTP query args to a Lua-friendly map
|
||||||
|
func QueryToLua(ctx *fasthttp.RequestCtx) map[string]any {
|
||||||
|
queryMap := make(map[string]any)
|
||||||
|
|
||||||
|
// Visit all query parameters
|
||||||
|
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||||
|
// Convert to string
|
||||||
|
k := string(key)
|
||||||
|
v := string(value)
|
||||||
|
|
||||||
|
// Check if this key already exists as an array
|
||||||
|
if existing, ok := queryMap[k]; ok {
|
||||||
|
// If it's already an array, append to it
|
||||||
|
if arr, ok := existing.([]string); ok {
|
||||||
|
queryMap[k] = append(arr, v)
|
||||||
|
} else if str, ok := existing.(string); ok {
|
||||||
|
// Convert existing string to array and append new value
|
||||||
|
queryMap[k] = []string{str, v}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// New key, store as string
|
||||||
|
queryMap[k] = v
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return queryMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseForm extracts form data from a request
|
||||||
|
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||||
|
formData := make(map[string]any)
|
||||||
|
|
||||||
|
// Check if multipart form
|
||||||
|
if strings.Contains(string(ctx.Request.Header.ContentType()), "multipart/form-data") {
|
||||||
|
return parseMultipartForm(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular form
|
||||||
|
ctx.PostArgs().VisitAll(func(key, value []byte) {
|
||||||
|
k := string(key)
|
||||||
|
v := string(value)
|
||||||
|
|
||||||
|
// Check if this key already exists
|
||||||
|
if existing, ok := formData[k]; ok {
|
||||||
|
// If it's already an array, append to it
|
||||||
|
if arr, ok := existing.([]string); ok {
|
||||||
|
formData[k] = append(arr, v)
|
||||||
|
} else if str, ok := existing.(string); ok {
|
||||||
|
// Convert existing string to array and append new value
|
||||||
|
formData[k] = []string{str, v}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// New key, store as string
|
||||||
|
formData[k] = v
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return formData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseMultipartForm handles multipart/form-data requests
|
||||||
|
func parseMultipartForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||||
|
formData := make(map[string]any)
|
||||||
|
|
||||||
|
// Parse multipart form
|
||||||
|
form, err := ctx.MultipartForm()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process form values
|
||||||
|
for key, values := range form.Value {
|
||||||
|
if len(values) == 1 {
|
||||||
|
formData[key] = values[0]
|
||||||
|
} else if len(values) > 1 {
|
||||||
|
formData[key] = values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process files (store file info, not the content)
|
||||||
|
if len(form.File) > 0 {
|
||||||
|
files := make(map[string]any)
|
||||||
|
|
||||||
|
for fieldName, fileHeaders := range form.File {
|
||||||
|
if len(fileHeaders) == 1 {
|
||||||
|
files[fieldName] = fileInfoToMap(fileHeaders[0])
|
||||||
|
} else if len(fileHeaders) > 1 {
|
||||||
|
fileInfos := make([]map[string]any, 0, len(fileHeaders))
|
||||||
|
for _, fh := range fileHeaders {
|
||||||
|
fileInfos = append(fileInfos, fileInfoToMap(fh))
|
||||||
|
}
|
||||||
|
files[fieldName] = fileInfos
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
formData["_files"] = files
|
||||||
|
}
|
||||||
|
|
||||||
|
return formData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fileInfoToMap converts a FileHeader to a map for Lua
|
||||||
|
func fileInfoToMap(fh *multipart.FileHeader) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"filename": fh.Filename,
|
||||||
|
"size": fh.Size,
|
||||||
|
"mimetype": getMimeType(fh),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMimeType gets the mime type from a file header
|
||||||
|
func getMimeType(fh *multipart.FileHeader) string {
|
||||||
|
if fh.Header != nil {
|
||||||
|
contentType := fh.Header.Get("Content-Type")
|
||||||
|
if contentType != "" {
|
||||||
|
return contentType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to basic type detection from filename
|
||||||
|
if strings.HasSuffix(fh.Filename, ".pdf") {
|
||||||
|
return "application/pdf"
|
||||||
|
} else if strings.HasSuffix(fh.Filename, ".png") {
|
||||||
|
return "image/png"
|
||||||
|
} else if strings.HasSuffix(fh.Filename, ".jpg") || strings.HasSuffix(fh.Filename, ".jpeg") {
|
||||||
|
return "image/jpeg"
|
||||||
|
} else if strings.HasSuffix(fh.Filename, ".gif") {
|
||||||
|
return "image/gif"
|
||||||
|
} else if strings.HasSuffix(fh.Filename, ".svg") {
|
||||||
|
return "image/svg+xml"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "application/octet-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateSecureToken creates a cryptographically secure random token
|
||||||
|
func GenerateSecureToken(length int) (string, error) {
|
||||||
|
b := make([]byte, length)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.URLEncoding.EncodeToString(b)[:length], nil
|
||||||
|
}
|
|
@ -3,8 +3,6 @@ package runner
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"Moonshark/core/sessions"
|
|
||||||
|
|
||||||
"github.com/valyala/bytebufferpool"
|
"github.com/valyala/bytebufferpool"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
@ -17,9 +15,9 @@ type Context struct {
|
||||||
// FastHTTP context if this was created from an HTTP request
|
// FastHTTP context if this was created from an HTTP request
|
||||||
RequestCtx *fasthttp.RequestCtx
|
RequestCtx *fasthttp.RequestCtx
|
||||||
|
|
||||||
// Session data and management
|
// Session information
|
||||||
Session *sessions.Session
|
SessionID string
|
||||||
SessionModified bool
|
SessionData map[string]any
|
||||||
|
|
||||||
// Buffer for efficient string operations
|
// Buffer for efficient string operations
|
||||||
buffer *bytebufferpool.ByteBuffer
|
buffer *bytebufferpool.ByteBuffer
|
||||||
|
@ -29,7 +27,8 @@ type Context struct {
|
||||||
var contextPool = sync.Pool{
|
var contextPool = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
return &Context{
|
return &Context{
|
||||||
Values: make(map[string]any, 16),
|
Values: make(map[string]any, 16),
|
||||||
|
SessionData: make(map[string]any, 8),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -43,6 +42,44 @@ func NewContext() *Context {
|
||||||
func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context {
|
func NewHTTPContext(requestCtx *fasthttp.RequestCtx) *Context {
|
||||||
ctx := NewContext()
|
ctx := NewContext()
|
||||||
ctx.RequestCtx = requestCtx
|
ctx.RequestCtx = requestCtx
|
||||||
|
|
||||||
|
// Extract common HTTP values that Lua might need
|
||||||
|
if requestCtx != nil {
|
||||||
|
ctx.Values["_request_method"] = string(requestCtx.Method())
|
||||||
|
ctx.Values["_request_path"] = string(requestCtx.Path())
|
||||||
|
ctx.Values["_request_url"] = string(requestCtx.RequestURI())
|
||||||
|
|
||||||
|
// Extract cookies
|
||||||
|
cookies := make(map[string]any)
|
||||||
|
requestCtx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||||
|
cookies[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
ctx.Values["_request_cookies"] = cookies
|
||||||
|
|
||||||
|
// Extract query params
|
||||||
|
query := make(map[string]any)
|
||||||
|
requestCtx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||||
|
query[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
ctx.Values["_request_query"] = query
|
||||||
|
|
||||||
|
// Extract form data if present
|
||||||
|
if requestCtx.IsPost() || requestCtx.IsPut() {
|
||||||
|
form := make(map[string]any)
|
||||||
|
requestCtx.PostArgs().VisitAll(func(key, value []byte) {
|
||||||
|
form[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
ctx.Values["_request_form"] = form
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract headers
|
||||||
|
headers := make(map[string]any)
|
||||||
|
requestCtx.Request.Header.VisitAll(func(key, value []byte) {
|
||||||
|
headers[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
ctx.Values["_request_headers"] = headers
|
||||||
|
}
|
||||||
|
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,9 +90,12 @@ func (c *Context) Release() {
|
||||||
delete(c.Values, k)
|
delete(c.Values, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for k := range c.SessionData {
|
||||||
|
delete(c.SessionData, k)
|
||||||
|
}
|
||||||
|
|
||||||
// Reset session info
|
// Reset session info
|
||||||
c.Session = nil
|
c.SessionID = ""
|
||||||
c.SessionModified = false
|
|
||||||
|
|
||||||
// Reset request context
|
// Reset request context
|
||||||
c.RequestCtx = nil
|
c.RequestCtx = nil
|
||||||
|
@ -87,13 +127,12 @@ func (c *Context) Get(key string) any {
|
||||||
return c.Values[key]
|
return c.Values[key]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains checks if a key exists in the context
|
// SetSession sets a session data value
|
||||||
func (c *Context) Contains(key string) bool {
|
func (c *Context) SetSession(key string, value any) {
|
||||||
_, exists := c.Values[key]
|
c.SessionData[key] = value
|
||||||
return exists
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a value from the context
|
// GetSession retrieves a session data value
|
||||||
func (c *Context) Delete(key string) {
|
func (c *Context) GetSession(key string) any {
|
||||||
delete(c.Values, key)
|
return c.SessionData[key]
|
||||||
}
|
}
|
|
@ -1,262 +0,0 @@
|
||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"Moonshark/core/runner/sandbox"
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CoreModuleRegistry manages the initialization and reloading of core modules
|
|
||||||
type CoreModuleRegistry struct {
|
|
||||||
modules map[string]sandbox.StateInitFunc // Module initializers
|
|
||||||
initOrder []string // Explicit initialization order
|
|
||||||
dependencies map[string][]string // Module dependencies
|
|
||||||
initializedFlag map[string]bool // Track which modules are initialized
|
|
||||||
mu sync.RWMutex
|
|
||||||
debug bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCoreModuleRegistry creates a new core module registry
|
|
||||||
func NewCoreModuleRegistry() *CoreModuleRegistry {
|
|
||||||
return &CoreModuleRegistry{
|
|
||||||
modules: make(map[string]sandbox.StateInitFunc),
|
|
||||||
initOrder: []string{},
|
|
||||||
dependencies: make(map[string][]string),
|
|
||||||
initializedFlag: make(map[string]bool),
|
|
||||||
debug: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnableDebug turns on debug logging
|
|
||||||
func (r *CoreModuleRegistry) EnableDebug() {
|
|
||||||
r.debug = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// debugLog prints debug messages if enabled
|
|
||||||
func (r *CoreModuleRegistry) debugLog(format string, args ...interface{}) {
|
|
||||||
if r.debug {
|
|
||||||
logger.Debug("CoreRegistry "+format, args...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register adds a module to the registry
|
|
||||||
func (r *CoreModuleRegistry) Register(name string, initFunc sandbox.StateInitFunc) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
r.modules[name] = initFunc
|
|
||||||
|
|
||||||
// Add to initialization order if not already there
|
|
||||||
for _, n := range r.initOrder {
|
|
||||||
if n == name {
|
|
||||||
return // Already registered
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.initOrder = append(r.initOrder, name)
|
|
||||||
r.debugLog("registered module %s", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterWithDependencies registers a module with explicit dependencies
|
|
||||||
func (r *CoreModuleRegistry) RegisterWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
r.modules[name] = initFunc
|
|
||||||
r.dependencies[name] = dependencies
|
|
||||||
|
|
||||||
// Add to initialization order if not already there
|
|
||||||
for _, n := range r.initOrder {
|
|
||||||
if n == name {
|
|
||||||
return // Already registered
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.initOrder = append(r.initOrder, name)
|
|
||||||
r.debugLog("registered module %s with dependencies: %v", name, dependencies)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetInitOrder sets explicit initialization order
|
|
||||||
func (r *CoreModuleRegistry) SetInitOrder(order []string) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
// Create new init order
|
|
||||||
newOrder := make([]string, 0, len(order))
|
|
||||||
|
|
||||||
// First add all known modules that are in the specified order
|
|
||||||
for _, name := range order {
|
|
||||||
if _, exists := r.modules[name]; exists && !contains(newOrder, name) {
|
|
||||||
newOrder = append(newOrder, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then add any modules not in the specified order
|
|
||||||
for name := range r.modules {
|
|
||||||
if !contains(newOrder, name) {
|
|
||||||
newOrder = append(newOrder, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.initOrder = newOrder
|
|
||||||
r.debugLog("Set initialization order: %v", r.initOrder)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize initializes all registered modules
|
|
||||||
func (r *CoreModuleRegistry) Initialize(state *luajit.State, stateIndex int) error {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
verbose := stateIndex == 0
|
|
||||||
if verbose {
|
|
||||||
r.debugLog("initializing %d modules...", len(r.initOrder))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear initialization flags
|
|
||||||
r.initializedFlag = make(map[string]bool)
|
|
||||||
|
|
||||||
// Initialize modules in order, respecting dependencies
|
|
||||||
for _, name := range r.initOrder {
|
|
||||||
if err := r.initializeModule(state, name, []string{}, verbose); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if verbose {
|
|
||||||
r.debugLog("All modules initialized successfully")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// initializeModule initializes a module and its dependencies
|
|
||||||
func (r *CoreModuleRegistry) initializeModule(state *luajit.State, name string,
|
|
||||||
initStack []string, verbose bool) error {
|
|
||||||
// Check if already initialized
|
|
||||||
if r.initializedFlag[name] {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for circular dependencies
|
|
||||||
for _, n := range initStack {
|
|
||||||
if n == name {
|
|
||||||
return fmt.Errorf("circular dependency detected: %s -> %s",
|
|
||||||
strings.Join(initStack, " -> "), name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get init function
|
|
||||||
initFunc, ok := r.modules[name]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("module not found: %s", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize dependencies first
|
|
||||||
deps := r.dependencies[name]
|
|
||||||
if len(deps) > 0 {
|
|
||||||
newStack := append(initStack, name)
|
|
||||||
for _, dep := range deps {
|
|
||||||
if err := r.initializeModule(state, dep, newStack, verbose); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := initFunc(state)
|
|
||||||
if err != nil {
|
|
||||||
// Always log failures regardless of verbose setting
|
|
||||||
r.debugLog("Initializing module %s... failure: %v", name, err)
|
|
||||||
return fmt.Errorf("failed to initialize module %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.initializedFlag[name] = true
|
|
||||||
|
|
||||||
if verbose {
|
|
||||||
r.debugLog("Initializing module %s... success", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitializeModule initializes a specific module
|
|
||||||
func (r *CoreModuleRegistry) InitializeModule(state *luajit.State, name string) error {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
// Clear initialization flag for this module
|
|
||||||
r.initializedFlag[name] = false
|
|
||||||
|
|
||||||
// Always use verbose logging for explicit module initialization
|
|
||||||
return r.initializeModule(state, name, []string{}, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MatchModuleName checks if a file path corresponds to a registered module
|
|
||||||
func (r *CoreModuleRegistry) MatchModuleName(modName string) (string, bool) {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
// Exact match
|
|
||||||
if _, ok := r.modules[modName]; ok {
|
|
||||||
return modName, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the module name ends with a registered module
|
|
||||||
for name := range r.modules {
|
|
||||||
if strings.HasSuffix(modName, "."+name) {
|
|
||||||
return name, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Global registry instance
|
|
||||||
var GlobalRegistry = NewCoreModuleRegistry()
|
|
||||||
|
|
||||||
// Initialize global registry with core modules
|
|
||||||
func init() {
|
|
||||||
GlobalRegistry.EnableDebug() // Enable debugging by default
|
|
||||||
logger.Debug("[ModuleRegistry] Registering core modules...")
|
|
||||||
|
|
||||||
// Register core modules
|
|
||||||
GlobalRegistry.Register("util", func(state *luajit.State) error {
|
|
||||||
return sandbox.UtilModuleInitFunc()(state)
|
|
||||||
})
|
|
||||||
|
|
||||||
GlobalRegistry.Register("http", func(state *luajit.State) error {
|
|
||||||
return sandbox.HTTPModuleInitFunc()(state)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Set explicit initialization order
|
|
||||||
GlobalRegistry.SetInitOrder([]string{
|
|
||||||
"util", // First: core utilities
|
|
||||||
"http", // Second: HTTP functionality
|
|
||||||
"session", // Third: Session functionality
|
|
||||||
"csrf", // Fourth: CSRF protection
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.Debug("Core modules registered successfully")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterCoreModule registers a core module with the global registry
|
|
||||||
func RegisterCoreModule(name string, initFunc sandbox.StateInitFunc) {
|
|
||||||
GlobalRegistry.Register(name, initFunc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterCoreModuleWithDependencies registers a module with dependencies
|
|
||||||
func RegisterCoreModuleWithDependencies(name string, initFunc sandbox.StateInitFunc, dependencies []string) {
|
|
||||||
GlobalRegistry.RegisterWithDependencies(name, initFunc, dependencies)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
func contains(slice []string, item string) bool {
|
|
||||||
for _, s := range slice {
|
|
||||||
if s == item {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
61
core/runner/Embed.go
Normal file
61
core/runner/Embed.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed sandbox.lua
|
||||||
|
var sandboxLuaCode string
|
||||||
|
|
||||||
|
// Global bytecode cache to improve performance
|
||||||
|
var (
|
||||||
|
sandboxBytecode atomic.Pointer[[]byte]
|
||||||
|
bytecodeOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
// precompileSandboxCode compiles the sandbox.lua code to bytecode once
|
||||||
|
func precompileSandboxCode() {
|
||||||
|
// Create temporary state for compilation
|
||||||
|
tempState := luajit.New()
|
||||||
|
if tempState == nil {
|
||||||
|
logger.Error("Failed to create temp Lua state for bytecode compilation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tempState.Close()
|
||||||
|
defer tempState.Cleanup()
|
||||||
|
|
||||||
|
code, err := tempState.CompileBytecode(sandboxLuaCode, "sandbox.lua")
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to compile sandbox code: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bytecode := make([]byte, len(code))
|
||||||
|
copy(bytecode, code)
|
||||||
|
sandboxBytecode.Store(&bytecode)
|
||||||
|
|
||||||
|
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code))
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadSandboxIntoState loads the sandbox code into a Lua state
|
||||||
|
func loadSandboxIntoState(state *luajit.State) error {
|
||||||
|
// Initialize bytecode once
|
||||||
|
bytecodeOnce.Do(precompileSandboxCode)
|
||||||
|
|
||||||
|
// Use precompiled bytecode if available
|
||||||
|
bytecode := sandboxBytecode.Load()
|
||||||
|
if bytecode != nil && len(*bytecode) > 0 {
|
||||||
|
logger.Debug("Loading sandbox.lua from precompiled bytecode")
|
||||||
|
return state.LoadAndRunBytecode(*bytecode, "sandbox.lua")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to direct execution
|
||||||
|
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
|
||||||
|
return state.DoString(sandboxLuaCode)
|
||||||
|
}
|
334
core/runner/Http.go
Normal file
334
core/runner/Http.go
Normal file
|
@ -0,0 +1,334 @@
|
||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/goccy/go-json"
|
||||||
|
"github.com/valyala/bytebufferpool"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default HTTP client with sensible timeout
|
||||||
|
var defaultFastClient = fasthttp.Client{
|
||||||
|
MaxConnsPerHost: 1024,
|
||||||
|
MaxIdleConnDuration: time.Minute,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
DisableHeaderNamesNormalizing: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClientConfig contains client settings
|
||||||
|
type HTTPClientConfig struct {
|
||||||
|
MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit)
|
||||||
|
DefaultTimeout time.Duration // Default request timeout
|
||||||
|
MaxResponseSize int64 // Maximum response size in bytes (0 = no limit)
|
||||||
|
AllowRemote bool // Whether to allow remote connections
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultHTTPClientConfig provides sensible defaults
|
||||||
|
var DefaultHTTPClientConfig = HTTPClientConfig{
|
||||||
|
MaxTimeout: 60 * time.Second,
|
||||||
|
DefaultTimeout: 30 * time.Second,
|
||||||
|
MaxResponseSize: 10 * 1024 * 1024, // 10MB
|
||||||
|
AllowRemote: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyResponse applies a Response to a fasthttp.RequestCtx
|
||||||
|
func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
|
||||||
|
// Set status code
|
||||||
|
ctx.SetStatusCode(resp.Status)
|
||||||
|
|
||||||
|
// Set headers
|
||||||
|
for name, value := range resp.Headers {
|
||||||
|
ctx.Response.Header.Set(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set cookies
|
||||||
|
for _, cookie := range resp.Cookies {
|
||||||
|
ctx.Response.Header.SetCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the body based on its type
|
||||||
|
if resp.Body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a buffer from the pool
|
||||||
|
buf := bytebufferpool.Get()
|
||||||
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
|
// Set body based on type
|
||||||
|
switch body := resp.Body.(type) {
|
||||||
|
case string:
|
||||||
|
ctx.SetBodyString(body)
|
||||||
|
case []byte:
|
||||||
|
ctx.SetBody(body)
|
||||||
|
case map[string]any, []any, []float64, []string, []int:
|
||||||
|
// Marshal JSON
|
||||||
|
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
||||||
|
// Set content type if not already set
|
||||||
|
if len(ctx.Response.Header.ContentType()) == 0 {
|
||||||
|
ctx.Response.Header.SetContentType("application/json")
|
||||||
|
}
|
||||||
|
ctx.SetBody(buf.Bytes())
|
||||||
|
} else {
|
||||||
|
// Fallback
|
||||||
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Default to string representation
|
||||||
|
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpRequest makes an HTTP request and returns the result to Lua
|
||||||
|
func httpRequest(state *luajit.State) int {
|
||||||
|
// Get method (required)
|
||||||
|
if !state.IsString(1) {
|
||||||
|
state.PushString("http.client.request: method must be a string")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
method := strings.ToUpper(state.ToString(1))
|
||||||
|
|
||||||
|
// Get URL (required)
|
||||||
|
if !state.IsString(2) {
|
||||||
|
state.PushString("http.client.request: url must be a string")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
urlStr := state.ToString(2)
|
||||||
|
|
||||||
|
// Parse URL to check if it's valid
|
||||||
|
parsedURL, err := url.Parse(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
state.PushString("Invalid URL: " + err.Error())
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get client configuration
|
||||||
|
config := DefaultHTTPClientConfig
|
||||||
|
|
||||||
|
// Check if remote connections are allowed
|
||||||
|
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
|
||||||
|
state.PushString("Remote connections are not allowed")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use bytebufferpool for request and response
|
||||||
|
req := fasthttp.AcquireRequest()
|
||||||
|
resp := fasthttp.AcquireResponse()
|
||||||
|
defer fasthttp.ReleaseRequest(req)
|
||||||
|
defer fasthttp.ReleaseResponse(resp)
|
||||||
|
|
||||||
|
// Set up request
|
||||||
|
req.Header.SetMethod(method)
|
||||||
|
req.SetRequestURI(urlStr)
|
||||||
|
req.Header.Set("User-Agent", "Moonshark/1.0")
|
||||||
|
|
||||||
|
// Get body (optional)
|
||||||
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||||
|
if state.IsString(3) {
|
||||||
|
// String body
|
||||||
|
req.SetBodyString(state.ToString(3))
|
||||||
|
} else if state.IsTable(3) {
|
||||||
|
// Table body - convert to JSON
|
||||||
|
luaTable, err := state.ToTable(3)
|
||||||
|
if err != nil {
|
||||||
|
state.PushString("Failed to parse body table: " + err.Error())
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use bytebufferpool for JSON serialization
|
||||||
|
buf := bytebufferpool.Get()
|
||||||
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
|
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
|
||||||
|
state.PushString("Failed to convert body to JSON: " + err.Error())
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetBody(buf.Bytes())
|
||||||
|
req.Header.SetContentType("application/json")
|
||||||
|
} else {
|
||||||
|
state.PushString("Body must be a string or table")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process options (headers, timeout, etc.)
|
||||||
|
timeout := config.DefaultTimeout
|
||||||
|
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
|
||||||
|
// Process headers
|
||||||
|
state.GetField(4, "headers")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
// Iterate through headers
|
||||||
|
state.PushNil() // Start iteration
|
||||||
|
for state.Next(-2) {
|
||||||
|
// Stack now has key at -2 and value at -1
|
||||||
|
if state.IsString(-2) && state.IsString(-1) {
|
||||||
|
headerName := state.ToString(-2)
|
||||||
|
headerValue := state.ToString(-1)
|
||||||
|
req.Header.Set(headerName, headerValue)
|
||||||
|
}
|
||||||
|
state.Pop(1) // Pop value, leave key for next iteration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1) // Pop headers table
|
||||||
|
|
||||||
|
// Get timeout
|
||||||
|
state.GetField(4, "timeout")
|
||||||
|
if state.IsNumber(-1) {
|
||||||
|
requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second
|
||||||
|
|
||||||
|
// Apply max timeout if configured
|
||||||
|
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
|
||||||
|
timeout = config.MaxTimeout
|
||||||
|
} else {
|
||||||
|
timeout = requestTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1) // Pop timeout
|
||||||
|
|
||||||
|
// Process query parameters
|
||||||
|
state.GetField(4, "query")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
// Create URL args
|
||||||
|
args := req.URI().QueryArgs()
|
||||||
|
|
||||||
|
// Iterate through query params
|
||||||
|
state.PushNil() // Start iteration
|
||||||
|
for state.Next(-2) {
|
||||||
|
if state.IsString(-2) {
|
||||||
|
paramName := state.ToString(-2)
|
||||||
|
|
||||||
|
// Handle different value types
|
||||||
|
if state.IsString(-1) {
|
||||||
|
args.Add(paramName, state.ToString(-1))
|
||||||
|
} else if state.IsNumber(-1) {
|
||||||
|
args.Add(paramName, strings.TrimRight(strings.TrimRight(
|
||||||
|
state.ToString(-1), "0"), "."))
|
||||||
|
} else if state.IsBoolean(-1) {
|
||||||
|
if state.ToBoolean(-1) {
|
||||||
|
args.Add(paramName, "true")
|
||||||
|
} else {
|
||||||
|
args.Add(paramName, "false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1) // Pop value, leave key for next iteration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1) // Pop query table
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create context with timeout
|
||||||
|
_, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Execute request
|
||||||
|
err = defaultFastClient.DoTimeout(req, resp, timeout)
|
||||||
|
if err != nil {
|
||||||
|
errStr := "Request failed: " + err.Error()
|
||||||
|
if errors.Is(err, fasthttp.ErrTimeout) {
|
||||||
|
errStr = "Request timed out after " + timeout.String()
|
||||||
|
}
|
||||||
|
state.PushString(errStr)
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create response table
|
||||||
|
state.NewTable()
|
||||||
|
|
||||||
|
// Set status code
|
||||||
|
state.PushNumber(float64(resp.StatusCode()))
|
||||||
|
state.SetField(-2, "status")
|
||||||
|
|
||||||
|
// Set status text
|
||||||
|
statusText := fasthttp.StatusMessage(resp.StatusCode())
|
||||||
|
state.PushString(statusText)
|
||||||
|
state.SetField(-2, "status_text")
|
||||||
|
|
||||||
|
// Set body
|
||||||
|
var respBody []byte
|
||||||
|
|
||||||
|
// Apply size limits to response
|
||||||
|
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
|
||||||
|
// Make a limited copy
|
||||||
|
respBody = make([]byte, config.MaxResponseSize)
|
||||||
|
copy(respBody, resp.Body())
|
||||||
|
} else {
|
||||||
|
respBody = resp.Body()
|
||||||
|
}
|
||||||
|
|
||||||
|
state.PushString(string(respBody))
|
||||||
|
state.SetField(-2, "body")
|
||||||
|
|
||||||
|
// Parse body as JSON if content type is application/json
|
||||||
|
contentType := string(resp.Header.ContentType())
|
||||||
|
if strings.Contains(contentType, "application/json") {
|
||||||
|
var jsonData any
|
||||||
|
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
||||||
|
if err := state.PushValue(jsonData); err == nil {
|
||||||
|
state.SetField(-2, "json")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set headers
|
||||||
|
state.NewTable()
|
||||||
|
resp.Header.VisitAll(func(key, value []byte) {
|
||||||
|
state.PushString(string(value))
|
||||||
|
state.SetField(-2, string(key))
|
||||||
|
})
|
||||||
|
state.SetField(-2, "headers")
|
||||||
|
|
||||||
|
// Create ok field (true if status code is 2xx)
|
||||||
|
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
|
||||||
|
state.SetField(-2, "ok")
|
||||||
|
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateToken creates a cryptographically secure random token
|
||||||
|
func generateToken(state *luajit.State) int {
|
||||||
|
// Get the length from the Lua arguments (default to 32)
|
||||||
|
length := 32
|
||||||
|
if state.GetTop() >= 1 && state.IsNumber(1) {
|
||||||
|
length = int(state.ToNumber(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enforce minimum length for security
|
||||||
|
if length < 16 {
|
||||||
|
length = 16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate secure random bytes
|
||||||
|
tokenBytes := make([]byte, length)
|
||||||
|
if _, err := rand.Read(tokenBytes); err != nil {
|
||||||
|
logger.Error("Failed to generate secure token: %v", err)
|
||||||
|
state.PushString("")
|
||||||
|
return 1 // Return empty string on error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode as base64
|
||||||
|
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
||||||
|
|
||||||
|
// Trim to requested length (base64 might be longer)
|
||||||
|
if len(token) > length {
|
||||||
|
token = token[:length]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push the token to the Lua stack
|
||||||
|
state.PushString(token)
|
||||||
|
return 1 // One return value
|
||||||
|
}
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,61 +17,15 @@ type ModuleConfig struct {
|
||||||
LibDirs []string // Additional library directories
|
LibDirs []string // Additional library directories
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModuleInfo stores information about a loaded module
|
|
||||||
type ModuleInfo struct {
|
|
||||||
Name string
|
|
||||||
Path string
|
|
||||||
IsCore bool
|
|
||||||
Bytecode []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModuleLoader manages module loading and caching
|
// ModuleLoader manages module loading and caching
|
||||||
type ModuleLoader struct {
|
type ModuleLoader struct {
|
||||||
config *ModuleConfig
|
config *ModuleConfig
|
||||||
registry *ModuleRegistry
|
|
||||||
pathCache map[string]string // Cache module paths for fast lookups
|
pathCache map[string]string // Cache module paths for fast lookups
|
||||||
bytecodeCache map[string][]byte // Cache of compiled bytecode
|
bytecodeCache map[string][]byte // Cache of compiled bytecode
|
||||||
debug bool
|
debug bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModuleRegistry keeps track of Lua modules for file watching
|
|
||||||
type ModuleRegistry struct {
|
|
||||||
// Maps file paths to module names
|
|
||||||
pathToModule sync.Map
|
|
||||||
// Maps module names to file paths
|
|
||||||
moduleToPath sync.Map
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewModuleRegistry creates a new module registry
|
|
||||||
func NewModuleRegistry() *ModuleRegistry {
|
|
||||||
return &ModuleRegistry{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register adds a module path to the registry
|
|
||||||
func (r *ModuleRegistry) Register(path string, name string) {
|
|
||||||
r.pathToModule.Store(path, name)
|
|
||||||
r.moduleToPath.Store(name, path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModuleName retrieves a module name by path
|
|
||||||
func (r *ModuleRegistry) GetModuleName(path string) (string, bool) {
|
|
||||||
value, ok := r.pathToModule.Load(path)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return value.(string), true
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModulePath retrieves a path by module name
|
|
||||||
func (r *ModuleRegistry) GetModulePath(name string) (string, bool) {
|
|
||||||
value, ok := r.moduleToPath.Load(name)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return value.(string), true
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewModuleLoader creates a new module loader
|
// NewModuleLoader creates a new module loader
|
||||||
func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
|
func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
|
@ -81,7 +37,6 @@ func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
|
||||||
|
|
||||||
return &ModuleLoader{
|
return &ModuleLoader{
|
||||||
config: config,
|
config: config,
|
||||||
registry: NewModuleRegistry(),
|
|
||||||
pathCache: make(map[string]string),
|
pathCache: make(map[string]string),
|
||||||
bytecodeCache: make(map[string][]byte),
|
bytecodeCache: make(map[string][]byte),
|
||||||
debug: false,
|
debug: false,
|
||||||
|
@ -100,6 +55,13 @@ func (l *ModuleLoader) SetScriptDir(dir string) {
|
||||||
l.config.ScriptDir = dir
|
l.config.ScriptDir = dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// debugLog logs a message if debug mode is enabled
|
||||||
|
func (l *ModuleLoader) debugLog(format string, args ...interface{}) {
|
||||||
|
if l.debug {
|
||||||
|
logger.Debug("ModuleLoader "+format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetupRequire configures the require system in a Lua state
|
// SetupRequire configures the require system in a Lua state
|
||||||
func (l *ModuleLoader) SetupRequire(state *luajit.State) error {
|
func (l *ModuleLoader) SetupRequire(state *luajit.State) error {
|
||||||
l.mu.RLock()
|
l.mu.RLock()
|
||||||
|
@ -207,6 +169,8 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
l.debugLog("Scanning directory: %s", absDir)
|
||||||
|
|
||||||
// Find all Lua files
|
// Find all Lua files
|
||||||
err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
|
err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
|
||||||
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") {
|
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") {
|
||||||
|
@ -223,19 +187,22 @@ func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
|
||||||
modName := strings.TrimSuffix(relPath, ".lua")
|
modName := strings.TrimSuffix(relPath, ".lua")
|
||||||
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
||||||
|
|
||||||
|
l.debugLog("Found module: %s at %s", modName, path)
|
||||||
|
|
||||||
// Register in our caches
|
// Register in our caches
|
||||||
l.pathCache[modName] = path
|
l.pathCache[modName] = path
|
||||||
l.registry.Register(path, modName)
|
|
||||||
|
|
||||||
// Load file content
|
// Load file content
|
||||||
content, err := os.ReadFile(path)
|
content, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
l.debugLog("Failed to read module file: %v", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile to bytecode
|
// Compile to bytecode
|
||||||
bytecode, err := state.CompileBytecode(string(content), path)
|
bytecode, err := state.CompileBytecode(string(content), path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
l.debugLog("Failed to compile module: %v", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,10 +321,11 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
|
||||||
// Clean path for proper comparison
|
// Clean path for proper comparison
|
||||||
path = filepath.Clean(path)
|
path = filepath.Clean(path)
|
||||||
|
|
||||||
// Try direct lookup from registry
|
// Try direct lookup from cache
|
||||||
modName, found := l.registry.GetModuleName(path)
|
for modName, modPath := range l.pathCache {
|
||||||
if found {
|
if modPath == path {
|
||||||
return modName, true
|
return modName, true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to find by relative path from lib dirs
|
// Try to find by relative path from lib dirs
|
||||||
|
@ -373,7 +341,7 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasSuffix(relPath, ".lua") {
|
if strings.HasSuffix(relPath, ".lua") {
|
||||||
modName = strings.TrimSuffix(relPath, ".lua")
|
modName := strings.TrimSuffix(relPath, ".lua")
|
||||||
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
||||||
return modName, true
|
return modName, true
|
||||||
}
|
}
|
||||||
|
@ -382,103 +350,6 @@ func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReloadModule reloads a module from disk
|
|
||||||
func (l *ModuleLoader) ReloadModule(state *luajit.State, name string) (bool, error) {
|
|
||||||
l.mu.Lock()
|
|
||||||
defer l.mu.Unlock()
|
|
||||||
|
|
||||||
// Get module path
|
|
||||||
path, ok := l.registry.GetModulePath(name)
|
|
||||||
if !ok {
|
|
||||||
for modName, modPath := range l.pathCache {
|
|
||||||
if modName == name {
|
|
||||||
path = modPath
|
|
||||||
ok = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok || path == "" {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Invalidate module in Lua
|
|
||||||
err := state.DoString(`
|
|
||||||
package.loaded["` + name + `"] = nil
|
|
||||||
__ready_modules["` + name + `"] = nil
|
|
||||||
if package.preload then
|
|
||||||
package.preload["` + name + `"] = nil
|
|
||||||
end
|
|
||||||
`)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if file still exists
|
|
||||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
|
||||||
// File was deleted, just invalidate
|
|
||||||
delete(l.pathCache, name)
|
|
||||||
delete(l.bytecodeCache, name)
|
|
||||||
l.registry.moduleToPath.Delete(name)
|
|
||||||
l.registry.pathToModule.Delete(path)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read updated file
|
|
||||||
content, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile to bytecode
|
|
||||||
bytecode, err := state.CompileBytecode(string(content), path)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update cache
|
|
||||||
l.bytecodeCache[name] = bytecode
|
|
||||||
|
|
||||||
// Load bytecode into state
|
|
||||||
if err := state.LoadBytecode(bytecode, path); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update preload
|
|
||||||
luaCode := `
|
|
||||||
local modname = "` + name + `"
|
|
||||||
package.loaded[modname] = nil
|
|
||||||
package.preload[modname] = ...
|
|
||||||
__ready_modules[modname] = true
|
|
||||||
`
|
|
||||||
|
|
||||||
if err := state.DoString(luaCode); err != nil {
|
|
||||||
state.Pop(1) // Remove chunk from stack
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
state.Pop(1) // Remove chunk from stack
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetModules clears non-core modules from package.loaded
|
|
||||||
func (l *ModuleLoader) ResetModules(state *luajit.State) error {
|
|
||||||
return state.DoString(`
|
|
||||||
local core_modules = {
|
|
||||||
string = true, table = true, math = true, os = true,
|
|
||||||
package = true, io = true, coroutine = true, debug = true, _G = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for name in pairs(package.loaded) do
|
|
||||||
if not core_modules[name] then
|
|
||||||
package.loaded[name] = nil
|
|
||||||
end
|
|
||||||
end
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// escapeLuaString escapes special characters in a string for Lua
|
// escapeLuaString escapes special characters in a string for Lua
|
||||||
func escapeLuaString(s string) string {
|
func escapeLuaString(s string) string {
|
||||||
replacer := strings.NewReplacer(
|
replacer := strings.NewReplacer(
|
||||||
|
|
76
core/runner/Response.go
Normal file
76
core/runner/Response.go
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response represents a unified response from script execution
|
||||||
|
type Response struct {
|
||||||
|
// Basic properties
|
||||||
|
Body any // Body content (any type)
|
||||||
|
Metadata map[string]any // Additional metadata
|
||||||
|
|
||||||
|
// HTTP specific properties
|
||||||
|
Status int // HTTP status code
|
||||||
|
Headers map[string]string // HTTP headers
|
||||||
|
Cookies []*fasthttp.Cookie // HTTP cookies
|
||||||
|
|
||||||
|
// Session information
|
||||||
|
SessionID string // Session ID
|
||||||
|
SessionData map[string]any // Session data
|
||||||
|
SessionModified bool // Whether session was modified
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response pool to reduce allocations
|
||||||
|
var responsePool = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return &Response{
|
||||||
|
Status: 200,
|
||||||
|
Headers: make(map[string]string, 8),
|
||||||
|
Metadata: make(map[string]any, 8),
|
||||||
|
Cookies: make([]*fasthttp.Cookie, 0, 4),
|
||||||
|
SessionData: make(map[string]any, 8),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponse creates a new response object from the pool
|
||||||
|
func NewResponse() *Response {
|
||||||
|
return responsePool.Get().(*Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release returns a response to the pool after cleaning it
|
||||||
|
func ReleaseResponse(resp *Response) {
|
||||||
|
if resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset fields to default values
|
||||||
|
resp.Body = nil
|
||||||
|
resp.Status = 200
|
||||||
|
|
||||||
|
// Clear maps
|
||||||
|
for k := range resp.Headers {
|
||||||
|
delete(resp.Headers, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k := range resp.Metadata {
|
||||||
|
delete(resp.Metadata, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k := range resp.SessionData {
|
||||||
|
delete(resp.SessionData, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear cookies
|
||||||
|
resp.Cookies = resp.Cookies[:0]
|
||||||
|
|
||||||
|
// Reset session info
|
||||||
|
resp.SessionID = ""
|
||||||
|
resp.SessionModified = false
|
||||||
|
|
||||||
|
// Return to pool
|
||||||
|
responsePool.Put(resp)
|
||||||
|
}
|
|
@ -9,8 +9,6 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
luaCtx "Moonshark/core/runner/context"
|
|
||||||
"Moonshark/core/runner/sandbox"
|
|
||||||
"Moonshark/core/utils/logger"
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
@ -29,30 +27,22 @@ type RunnerOption func(*Runner)
|
||||||
|
|
||||||
// State wraps a Lua state with its sandbox
|
// State wraps a Lua state with its sandbox
|
||||||
type State struct {
|
type State struct {
|
||||||
L *luajit.State // The Lua state
|
L *luajit.State // The Lua state
|
||||||
sandbox *sandbox.Sandbox // Associated sandbox
|
sandbox *Sandbox // Associated sandbox
|
||||||
index int // Index for debugging
|
index int // Index for debugging
|
||||||
inUse bool // Whether the state is currently in use
|
inUse bool // Whether the state is currently in use
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitHook runs before executing a script
|
|
||||||
type InitHook func(*luajit.State, *luaCtx.Context) error
|
|
||||||
|
|
||||||
// FinalizeHook runs after executing a script
|
|
||||||
type FinalizeHook func(*luajit.State, *luaCtx.Context, any) error
|
|
||||||
|
|
||||||
// Runner runs Lua scripts using a pool of Lua states
|
// Runner runs Lua scripts using a pool of Lua states
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
states []*State // All states managed by this runner
|
states []*State // All states managed by this runner
|
||||||
statePool chan int // Pool of available state indexes
|
statePool chan int // Pool of available state indexes
|
||||||
poolSize int // Size of the state pool
|
poolSize int // Size of the state pool
|
||||||
moduleLoader *ModuleLoader // Module loader
|
moduleLoader *ModuleLoader // Module loader
|
||||||
isRunning atomic.Bool // Whether the runner is active
|
isRunning atomic.Bool // Whether the runner is active
|
||||||
mu sync.RWMutex // Mutex for thread safety
|
mu sync.RWMutex // Mutex for thread safety
|
||||||
debug bool // Enable debug logging
|
debug bool // Enable debug logging
|
||||||
initHooks []InitHook // Hooks run before script execution
|
scriptDir string // Current script directory
|
||||||
finalizeHooks []FinalizeHook // Hooks run after script execution
|
|
||||||
scriptDir string // Current script directory
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithPoolSize sets the state pool size
|
// WithPoolSize sets the state pool size
|
||||||
|
@ -84,28 +74,12 @@ func WithLibDirs(dirs ...string) RunnerOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithInitHook adds a hook to run before script execution
|
|
||||||
func WithInitHook(hook InitHook) RunnerOption {
|
|
||||||
return func(r *Runner) {
|
|
||||||
r.initHooks = append(r.initHooks, hook)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithFinalizeHook adds a hook to run after script execution
|
|
||||||
func WithFinalizeHook(hook FinalizeHook) RunnerOption {
|
|
||||||
return func(r *Runner) {
|
|
||||||
r.finalizeHooks = append(r.finalizeHooks, hook)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRunner creates a new Runner with a pool of states
|
// NewRunner creates a new Runner with a pool of states
|
||||||
func NewRunner(options ...RunnerOption) (*Runner, error) {
|
func NewRunner(options ...RunnerOption) (*Runner, error) {
|
||||||
// Default configuration
|
// Default configuration
|
||||||
runner := &Runner{
|
runner := &Runner{
|
||||||
poolSize: runtime.GOMAXPROCS(0),
|
poolSize: runtime.GOMAXPROCS(0),
|
||||||
debug: false,
|
debug: false,
|
||||||
initHooks: make([]InitHook, 0, 4),
|
|
||||||
finalizeHooks: make([]FinalizeHook, 0, 4),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options
|
// Apply options
|
||||||
|
@ -122,6 +96,11 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
|
||||||
runner.moduleLoader = NewModuleLoader(config)
|
runner.moduleLoader = NewModuleLoader(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enable debug if requested
|
||||||
|
if runner.debug {
|
||||||
|
runner.moduleLoader.EnableDebug()
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize states and pool
|
// Initialize states and pool
|
||||||
runner.states = make([]*State, runner.poolSize)
|
runner.states = make([]*State, runner.poolSize)
|
||||||
runner.statePool = make(chan int, runner.poolSize)
|
runner.statePool = make(chan int, runner.poolSize)
|
||||||
|
@ -145,7 +124,7 @@ func (r *Runner) debugLog(format string, args ...interface{}) {
|
||||||
|
|
||||||
// initializeStates creates and initializes all states in the pool
|
// initializeStates creates and initializes all states in the pool
|
||||||
func (r *Runner) initializeStates() error {
|
func (r *Runner) initializeStates() error {
|
||||||
r.debugLog("is initializing %d states", r.poolSize)
|
r.debugLog("Initializing %d states", r.poolSize)
|
||||||
|
|
||||||
// Create all states
|
// Create all states
|
||||||
for i := 0; i < r.poolSize; i++ {
|
for i := 0; i < r.poolSize; i++ {
|
||||||
|
@ -175,39 +154,36 @@ func (r *Runner) createState(index int) (*State, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create sandbox
|
// Create sandbox
|
||||||
sb := sandbox.NewSandbox()
|
sb := NewSandbox()
|
||||||
if r.debug && verbose {
|
if r.debug {
|
||||||
sb.EnableDebug()
|
sb.EnableDebug()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up require system
|
// Set up sandbox
|
||||||
|
if err := sb.Setup(L); err != nil {
|
||||||
|
L.Cleanup()
|
||||||
|
L.Close()
|
||||||
|
return nil, ErrInitFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up module loader
|
||||||
if err := r.moduleLoader.SetupRequire(L); err != nil {
|
if err := r.moduleLoader.SetupRequire(L); err != nil {
|
||||||
L.Cleanup()
|
L.Cleanup()
|
||||||
L.Close()
|
L.Close()
|
||||||
return nil, ErrInitFailed
|
return nil, ErrInitFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize all core modules from the registry
|
// Preload modules
|
||||||
if err := GlobalRegistry.Initialize(L, index); err != nil {
|
|
||||||
L.Cleanup()
|
|
||||||
L.Close()
|
|
||||||
return nil, ErrInitFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up sandbox after core modules are initialized
|
|
||||||
if err := sb.Setup(L, index); err != nil {
|
|
||||||
L.Cleanup()
|
|
||||||
L.Close()
|
|
||||||
return nil, ErrInitFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Preload all modules
|
|
||||||
if err := r.moduleLoader.PreloadModules(L); err != nil {
|
if err := r.moduleLoader.PreloadModules(L); err != nil {
|
||||||
L.Cleanup()
|
L.Cleanup()
|
||||||
L.Close()
|
L.Close()
|
||||||
return nil, errors.New("failed to preload modules")
|
return nil, errors.New("failed to preload modules")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if verbose {
|
||||||
|
r.debugLog("Lua state %d initialized successfully", index)
|
||||||
|
}
|
||||||
|
|
||||||
return &State{
|
return &State{
|
||||||
L: L,
|
L: L,
|
||||||
sandbox: sb,
|
sandbox: sb,
|
||||||
|
@ -216,8 +192,8 @@ func (r *Runner) createState(index int) (*State, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute runs a script with context
|
// Execute runs a script in a sandbox with context
|
||||||
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *luaCtx.Context, scriptPath string) (any, error) {
|
func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
|
||||||
if !r.isRunning.Load() {
|
if !r.isRunning.Load() {
|
||||||
return nil, ErrRunnerClosed
|
return nil, ErrRunnerClosed
|
||||||
}
|
}
|
||||||
|
@ -264,70 +240,17 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *luaCtx.C
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Run init hooks
|
// Execute in sandbox
|
||||||
for _, hook := range r.initHooks {
|
response, err := state.sandbox.Execute(state.L, bytecode, execCtx)
|
||||||
if err := hook(state.L, execCtx); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get context values
|
|
||||||
var ctxValues map[string]any
|
|
||||||
if execCtx != nil {
|
|
||||||
ctxValues = execCtx.Values
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute in sandbox with optimized context handling
|
|
||||||
var result any
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if execCtx != nil && execCtx.RequestCtx != nil {
|
|
||||||
// Use OptimizedExecute directly with the full context if we have RequestCtx
|
|
||||||
result, err = state.sandbox.OptimizedExecute(state.L, bytecode, &luaCtx.Context{
|
|
||||||
Values: ctxValues,
|
|
||||||
RequestCtx: execCtx.RequestCtx,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// Otherwise use standard Execute with just values
|
|
||||||
result, err = state.sandbox.Execute(state.L, bytecode, ctxValues)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run finalize hooks
|
return response, nil
|
||||||
for _, hook := range r.finalizeHooks {
|
|
||||||
if hookErr := hook(state.L, execCtx, result); hookErr != nil {
|
|
||||||
return nil, hookErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for HTTP response if we don't have a RequestCtx or if we still have a result
|
|
||||||
if execCtx == nil || execCtx.RequestCtx == nil || result != nil {
|
|
||||||
httpResp, hasResponse := sandbox.GetHTTPResponse(state.L)
|
|
||||||
if hasResponse {
|
|
||||||
// Set result as body if not already set
|
|
||||||
if httpResp.Body == nil {
|
|
||||||
httpResp.Body = result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply directly to request context if available
|
|
||||||
if execCtx != nil && execCtx.RequestCtx != nil {
|
|
||||||
sandbox.ApplyHTTPResponse(httpResp, execCtx.RequestCtx)
|
|
||||||
sandbox.ReleaseResponse(httpResp)
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return httpResp, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run executes a Lua script (convenience wrapper)
|
// Run executes a Lua script with immediate context
|
||||||
func (r *Runner) Run(bytecode []byte, execCtx *luaCtx.Context, scriptPath string) (any, error) {
|
func (r *Runner) Run(bytecode []byte, execCtx *Context, scriptPath string) (*Response, error) {
|
||||||
return r.Execute(context.Background(), bytecode, execCtx, scriptPath)
|
return r.Execute(context.Background(), bytecode, execCtx, scriptPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -363,6 +286,7 @@ cleanup:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.debugLog("Runner closed")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -375,6 +299,8 @@ func (r *Runner) RefreshStates() error {
|
||||||
return ErrRunnerClosed
|
return ErrRunnerClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.debugLog("Refreshing all states...")
|
||||||
|
|
||||||
// Drain all states from the pool
|
// Drain all states from the pool
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -408,81 +334,6 @@ cleanup:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddInitHook adds a hook to be called before script execution
|
|
||||||
func (r *Runner) AddInitHook(hook InitHook) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.initHooks = append(r.initHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFinalizeHook adds a hook to be called after script execution
|
|
||||||
func (r *Runner) AddFinalizeHook(hook FinalizeHook) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.finalizeHooks = append(r.finalizeHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetStateCount returns the number of initialized states
|
|
||||||
func (r *Runner) GetStateCount() int {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
for _, state := range r.states {
|
|
||||||
if state != nil {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetActiveStateCount returns the number of states currently in use
|
|
||||||
func (r *Runner) GetActiveStateCount() int {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
for _, state := range r.states {
|
|
||||||
if state != nil && state.inUse {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModuleCount returns the number of loaded modules in the first available state
|
|
||||||
func (r *Runner) GetModuleCount() int {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
if !r.isRunning.Load() {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find first available state
|
|
||||||
for _, state := range r.states {
|
|
||||||
if state != nil && !state.inUse {
|
|
||||||
// Execute a Lua snippet to count modules
|
|
||||||
if res, err := state.L.ExecuteWithResult(`
|
|
||||||
local count = 0
|
|
||||||
for _ in pairs(package.loaded) do
|
|
||||||
count = count + 1
|
|
||||||
end
|
|
||||||
return count
|
|
||||||
`); err == nil {
|
|
||||||
if num, ok := res.(float64); ok {
|
|
||||||
return int(num)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotifyFileChanged alerts the runner about file changes
|
// NotifyFileChanged alerts the runner about file changes
|
||||||
func (r *Runner) NotifyFileChanged(filePath string) bool {
|
func (r *Runner) NotifyFileChanged(filePath string) bool {
|
||||||
r.debugLog("File change detected: %s", filePath)
|
r.debugLog("File change detected: %s", filePath)
|
||||||
|
@ -514,9 +365,6 @@ func (r *Runner) RefreshModule(moduleName string) bool {
|
||||||
|
|
||||||
r.debugLog("Refreshing module: %s", moduleName)
|
r.debugLog("Refreshing module: %s", moduleName)
|
||||||
|
|
||||||
// Check if it's a core module
|
|
||||||
coreName, isCore := GlobalRegistry.MatchModuleName(moduleName)
|
|
||||||
|
|
||||||
success := true
|
success := true
|
||||||
for _, state := range r.states {
|
for _, state := range r.states {
|
||||||
if state == nil || state.inUse {
|
if state == nil || state.inUse {
|
||||||
|
@ -526,16 +374,39 @@ func (r *Runner) RefreshModule(moduleName string) bool {
|
||||||
// Invalidate module in Lua
|
// Invalidate module in Lua
|
||||||
if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil {
|
if err := state.L.DoString(`package.loaded["` + moduleName + `"] = nil`); err != nil {
|
||||||
success = false
|
success = false
|
||||||
continue
|
r.debugLog("Failed to invalidate module %s: %v", moduleName, err)
|
||||||
}
|
|
||||||
|
|
||||||
// For core modules, reinitialize them
|
|
||||||
if isCore {
|
|
||||||
if err := GlobalRegistry.InitializeModule(state.L, coreName); err != nil {
|
|
||||||
success = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return success
|
return success
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStateCount returns the number of initialized states
|
||||||
|
func (r *Runner) GetStateCount() int {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for _, state := range r.states {
|
||||||
|
if state != nil {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveStateCount returns the number of states currently in use
|
||||||
|
func (r *Runner) GetActiveStateCount() int {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for _, state := range r.states {
|
||||||
|
if state != nil && state.inUse {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
345
core/runner/Sandbox.go
Normal file
345
core/runner/Sandbox.go
Normal file
|
@ -0,0 +1,345 @@
|
||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/valyala/bytebufferpool"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
"Moonshark/core/utils/logger"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error represents a simple error string
|
||||||
|
type Error string
|
||||||
|
|
||||||
|
func (e Error) Error() string {
|
||||||
|
return string(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error types
|
||||||
|
var (
|
||||||
|
ErrSandboxNotInitialized = Error("sandbox not initialized")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sandbox provides a secure execution environment for Lua scripts
|
||||||
|
type Sandbox struct {
|
||||||
|
modules map[string]any
|
||||||
|
debug bool
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSandbox creates a new sandbox environment
|
||||||
|
func NewSandbox() *Sandbox {
|
||||||
|
return &Sandbox{
|
||||||
|
modules: make(map[string]any, 8),
|
||||||
|
debug: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableDebug turns on debug logging
|
||||||
|
func (s *Sandbox) EnableDebug() {
|
||||||
|
s.debug = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// debugLog logs a message if debug mode is enabled
|
||||||
|
func (s *Sandbox) debugLog(format string, args ...interface{}) {
|
||||||
|
if s.debug {
|
||||||
|
logger.Debug("Sandbox "+format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddModule adds a module to the sandbox environment
|
||||||
|
func (s *Sandbox) AddModule(name string, module any) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.modules[name] = module
|
||||||
|
s.debugLog("Added module: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup initializes the sandbox in a Lua state
|
||||||
|
func (s *Sandbox) Setup(state *luajit.State) error {
|
||||||
|
s.debugLog("Setting up sandbox...")
|
||||||
|
|
||||||
|
// Load the sandbox code
|
||||||
|
if err := loadSandboxIntoState(state); err != nil {
|
||||||
|
s.debugLog("Failed to load sandbox: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register core functions
|
||||||
|
if err := s.registerCoreFunctions(state); err != nil {
|
||||||
|
s.debugLog("Failed to register core functions: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register custom modules in the global environment
|
||||||
|
s.mu.RLock()
|
||||||
|
for name, module := range s.modules {
|
||||||
|
s.debugLog("Registering module: %s", name)
|
||||||
|
if err := state.PushValue(module); err != nil {
|
||||||
|
s.mu.RUnlock()
|
||||||
|
s.debugLog("Failed to register module %s: %v", name, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
state.SetGlobal(name)
|
||||||
|
}
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
s.debugLog("Sandbox setup complete")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerCoreFunctions registers all built-in functions in the Lua state
|
||||||
|
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
|
||||||
|
// Register HTTP functions
|
||||||
|
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register utility functions
|
||||||
|
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional registrations can be added here
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs a Lua script in the sandbox with the given context
|
||||||
|
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
|
||||||
|
s.debugLog("Executing script...")
|
||||||
|
|
||||||
|
// Create a response object
|
||||||
|
response := NewResponse()
|
||||||
|
|
||||||
|
// Get a buffer for string operations
|
||||||
|
buf := bytebufferpool.Get()
|
||||||
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
|
// Load bytecode
|
||||||
|
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||||
|
ReleaseResponse(response)
|
||||||
|
s.debugLog("Failed to load bytecode: %v", err)
|
||||||
|
return nil, fmt.Errorf("failed to load script: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize session data in Lua
|
||||||
|
if ctx.SessionID != "" {
|
||||||
|
// Set session ID
|
||||||
|
state.PushString(ctx.SessionID)
|
||||||
|
state.SetGlobal("__session_id")
|
||||||
|
|
||||||
|
// Set session data
|
||||||
|
if err := state.PushTable(ctx.SessionData); err != nil {
|
||||||
|
ReleaseResponse(response)
|
||||||
|
s.debugLog("Failed to push session data: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state.SetGlobal("__session_data")
|
||||||
|
|
||||||
|
// Reset modification flag
|
||||||
|
state.PushBoolean(false)
|
||||||
|
state.SetGlobal("__session_modified")
|
||||||
|
} else {
|
||||||
|
// Initialize empty session
|
||||||
|
if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil {
|
||||||
|
s.debugLog("Failed to initialize empty session data: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up context values for execution
|
||||||
|
if err := state.PushTable(ctx.Values); err != nil {
|
||||||
|
ReleaseResponse(response)
|
||||||
|
s.debugLog("Failed to push context values: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the execution function
|
||||||
|
state.GetGlobal("__execute_script")
|
||||||
|
if !state.IsFunction(-1) {
|
||||||
|
state.Pop(1) // Pop non-function
|
||||||
|
ReleaseResponse(response)
|
||||||
|
s.debugLog("__execute_script is not a function")
|
||||||
|
return nil, ErrSandboxNotInitialized
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push function and context to stack
|
||||||
|
state.PushCopy(-2) // bytecode
|
||||||
|
state.PushCopy(-2) // context
|
||||||
|
|
||||||
|
// Remove duplicates
|
||||||
|
state.Remove(-4)
|
||||||
|
state.Remove(-3)
|
||||||
|
|
||||||
|
// Execute with 2 args, 1 result
|
||||||
|
if err := state.Call(2, 1); err != nil {
|
||||||
|
ReleaseResponse(response)
|
||||||
|
s.debugLog("Execution failed: %v", err)
|
||||||
|
return nil, fmt.Errorf("script execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set response body from result
|
||||||
|
body, err := state.ToValue(-1)
|
||||||
|
if err == nil {
|
||||||
|
response.Body = body
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Extract HTTP response data from Lua state
|
||||||
|
s.extractResponseData(state, response)
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractResponseData pulls response info from the Lua state
|
||||||
|
func (s *Sandbox) extractResponseData(state *luajit.State, response *Response) {
|
||||||
|
// Get HTTP response
|
||||||
|
state.GetGlobal("__http_responses")
|
||||||
|
if !state.IsNil(-1) && state.IsTable(-1) {
|
||||||
|
state.PushNumber(1)
|
||||||
|
state.GetTable(-2)
|
||||||
|
|
||||||
|
if !state.IsNil(-1) && state.IsTable(-1) {
|
||||||
|
// Extract status
|
||||||
|
state.GetField(-1, "status")
|
||||||
|
if state.IsNumber(-1) {
|
||||||
|
response.Status = int(state.ToNumber(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Extract headers
|
||||||
|
state.GetField(-1, "headers")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
state.PushNil() // Start iteration
|
||||||
|
for state.Next(-2) {
|
||||||
|
if state.IsString(-2) && state.IsString(-1) {
|
||||||
|
key := state.ToString(-2)
|
||||||
|
value := state.ToString(-1)
|
||||||
|
response.Headers[key] = value
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Extract cookies
|
||||||
|
state.GetField(-1, "cookies")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
length := state.GetTableLength(-1)
|
||||||
|
for i := 1; i <= length; i++ {
|
||||||
|
state.PushNumber(float64(i))
|
||||||
|
state.GetTable(-2)
|
||||||
|
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
s.extractCookie(state, response)
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Extract metadata if present
|
||||||
|
state.GetField(-1, "metadata")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
table, err := state.ToTable(-1)
|
||||||
|
if err == nil {
|
||||||
|
for k, v := range table {
|
||||||
|
response.Metadata[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Extract session data
|
||||||
|
state.GetGlobal("__session_modified")
|
||||||
|
if state.IsBoolean(-1) && state.ToBoolean(-1) {
|
||||||
|
response.SessionModified = true
|
||||||
|
|
||||||
|
// Get session ID
|
||||||
|
state.GetGlobal("__session_id")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
response.SessionID = state.ToString(-1)
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get session data
|
||||||
|
state.GetGlobal("__session_data")
|
||||||
|
if state.IsTable(-1) {
|
||||||
|
sessionData, err := state.ToTable(-1)
|
||||||
|
if err == nil {
|
||||||
|
for k, v := range sessionData {
|
||||||
|
response.SessionData[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCookie pulls cookie data from the current table on the stack
|
||||||
|
func (s *Sandbox) extractCookie(state *luajit.State, response *Response) {
|
||||||
|
cookie := fasthttp.AcquireCookie()
|
||||||
|
|
||||||
|
// Get name (required)
|
||||||
|
state.GetField(-1, "name")
|
||||||
|
if !state.IsString(-1) {
|
||||||
|
state.Pop(1)
|
||||||
|
fasthttp.ReleaseCookie(cookie)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cookie.SetKey(state.ToString(-1))
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get value
|
||||||
|
state.GetField(-1, "value")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
cookie.SetValue(state.ToString(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get path
|
||||||
|
state.GetField(-1, "path")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
cookie.SetPath(state.ToString(-1))
|
||||||
|
} else {
|
||||||
|
cookie.SetPath("/") // Default
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get domain
|
||||||
|
state.GetField(-1, "domain")
|
||||||
|
if state.IsString(-1) {
|
||||||
|
cookie.SetDomain(state.ToString(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get other parameters
|
||||||
|
state.GetField(-1, "http_only")
|
||||||
|
if state.IsBoolean(-1) {
|
||||||
|
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
state.GetField(-1, "secure")
|
||||||
|
if state.IsBoolean(-1) {
|
||||||
|
cookie.SetSecure(state.ToBoolean(-1))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
state.GetField(-1, "max_age")
|
||||||
|
if state.IsNumber(-1) {
|
||||||
|
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
||||||
|
}
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
response.Cookies = append(response.Cookies, cookie)
|
||||||
|
}
|
|
@ -1,241 +0,0 @@
|
||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
luaCtx "Moonshark/core/runner/context"
|
|
||||||
"Moonshark/core/runner/sandbox"
|
|
||||||
"Moonshark/core/sessions"
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SessionHandler handles session management for Lua scripts
|
|
||||||
type SessionHandler struct {
|
|
||||||
manager *sessions.SessionManager
|
|
||||||
debugLog bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSessionHandler creates a new session handler
|
|
||||||
func NewSessionHandler(manager *sessions.SessionManager) *SessionHandler {
|
|
||||||
return &SessionHandler{
|
|
||||||
manager: manager,
|
|
||||||
debugLog: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnableDebug enables debug logging
|
|
||||||
func (h *SessionHandler) EnableDebug() {
|
|
||||||
h.debugLog = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithSessionManager creates a RunnerOption to add session support
|
|
||||||
func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
|
|
||||||
return func(r *Runner) {
|
|
||||||
handler := NewSessionHandler(manager)
|
|
||||||
r.AddInitHook(handler.preRequestHook)
|
|
||||||
r.AddFinalizeHook(handler.postRequestHook)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// preRequestHook initializes session before script execution
|
|
||||||
func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *luaCtx.Context) error {
|
|
||||||
if ctx == nil || ctx.Values["_request_cookies"] == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract cookies from context
|
|
||||||
cookies, ok := ctx.Values["_request_cookies"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the session ID from cookies
|
|
||||||
cookieName := h.manager.CookieOptions()["name"].(string)
|
|
||||||
var sessionID string
|
|
||||||
|
|
||||||
// Check if our session cookie exists
|
|
||||||
if cookieValue, exists := cookies[cookieName]; exists {
|
|
||||||
if strValue, ok := cookieValue.(string); ok && strValue != "" {
|
|
||||||
sessionID = strValue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create new session if needed
|
|
||||||
if sessionID == "" {
|
|
||||||
session := h.manager.CreateSession()
|
|
||||||
sessionID = session.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the session ID in the context
|
|
||||||
ctx.Set("_session_id", sessionID)
|
|
||||||
|
|
||||||
// Get session data
|
|
||||||
session := h.manager.GetSession(sessionID)
|
|
||||||
sessionData := session.GetAll()
|
|
||||||
|
|
||||||
// Set session data in Lua state
|
|
||||||
return SetSessionData(state, sessionID, sessionData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// postRequestHook handles session after script execution
|
|
||||||
func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *luaCtx.Context, result any) error {
|
|
||||||
// Check if session was modified
|
|
||||||
modifiedID, modifiedData, modified := GetSessionData(state)
|
|
||||||
if !modified {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the original session ID from context
|
|
||||||
var sessionID string
|
|
||||||
if ctx != nil {
|
|
||||||
if id, ok := ctx.Values["_session_id"].(string); ok {
|
|
||||||
sessionID = id
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use the original session ID if the modified one is empty
|
|
||||||
if modifiedID == "" {
|
|
||||||
modifiedID = sessionID
|
|
||||||
}
|
|
||||||
|
|
||||||
if modifiedID == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update session in manager
|
|
||||||
session := h.manager.GetSession(modifiedID)
|
|
||||||
session.Clear() // clear to sync deleted values
|
|
||||||
for k, v := range modifiedData {
|
|
||||||
session.Set(k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.manager.SaveSession(session)
|
|
||||||
|
|
||||||
// Add session cookie to result if it's an HTTP response
|
|
||||||
if httpResp, ok := result.(*sandbox.HTTPResponse); ok {
|
|
||||||
h.addSessionCookie(httpResp, modifiedID)
|
|
||||||
} else if ctx != nil && ctx.RequestCtx != nil {
|
|
||||||
// Add cookie directly to the RequestCtx when result is not an HTTP response
|
|
||||||
h.addSessionCookieToRequestCtx(ctx.RequestCtx, modifiedID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addSessionCookie adds a session cookie to an HTTP response
|
|
||||||
func (h *SessionHandler) addSessionCookie(resp *sandbox.HTTPResponse, sessionID string) {
|
|
||||||
// Get cookie options
|
|
||||||
opts := h.manager.CookieOptions()
|
|
||||||
|
|
||||||
// Check if session cookie is already set
|
|
||||||
cookieName := opts["name"].(string)
|
|
||||||
for _, cookie := range resp.Cookies {
|
|
||||||
if string(cookie.Key()) == cookieName {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create and add cookie
|
|
||||||
cookie := fasthttp.AcquireCookie()
|
|
||||||
cookie.SetKey(cookieName)
|
|
||||||
cookie.SetValue(sessionID)
|
|
||||||
cookie.SetPath(opts["path"].(string))
|
|
||||||
cookie.SetHTTPOnly(opts["http_only"].(bool))
|
|
||||||
cookie.SetMaxAge(opts["max_age"].(int))
|
|
||||||
|
|
||||||
// Optional cookie parameters
|
|
||||||
if domain, ok := opts["domain"].(string); ok && domain != "" {
|
|
||||||
cookie.SetDomain(domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
if secure, ok := opts["secure"].(bool); ok {
|
|
||||||
cookie.SetSecure(secure)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.Cookies = append(resp.Cookies, cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *SessionHandler) addSessionCookieToRequestCtx(ctx *fasthttp.RequestCtx, sessionID string) {
|
|
||||||
// Get cookie options
|
|
||||||
opts := h.manager.CookieOptions()
|
|
||||||
cookieName := opts["name"].(string)
|
|
||||||
|
|
||||||
// Create cookie
|
|
||||||
cookie := fasthttp.AcquireCookie()
|
|
||||||
defer fasthttp.ReleaseCookie(cookie)
|
|
||||||
|
|
||||||
cookie.SetKey(cookieName)
|
|
||||||
cookie.SetValue(sessionID)
|
|
||||||
cookie.SetPath(opts["path"].(string))
|
|
||||||
cookie.SetHTTPOnly(opts["http_only"].(bool))
|
|
||||||
cookie.SetMaxAge(opts["max_age"].(int))
|
|
||||||
|
|
||||||
// Optional cookie parameters
|
|
||||||
if domain, ok := opts["domain"].(string); ok && domain != "" {
|
|
||||||
cookie.SetDomain(domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
if secure, ok := opts["secure"].(bool); ok {
|
|
||||||
cookie.SetSecure(secure)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Response.Header.SetCookie(cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSessionData extracts session data from Lua state
|
|
||||||
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
|
|
||||||
// Check if session was modified
|
|
||||||
state.GetGlobal("__session_modified")
|
|
||||||
modified := state.ToBoolean(-1)
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
if !modified {
|
|
||||||
return "", nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get session ID
|
|
||||||
state.GetGlobal("__session_id")
|
|
||||||
sessionID := state.ToString(-1)
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get session data
|
|
||||||
state.GetGlobal("__session_data")
|
|
||||||
if !state.IsTable(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
return sessionID, nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := state.ToTable(-1)
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to extract session data: %v", err)
|
|
||||||
return sessionID, nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return sessionID, data, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSessionData sets session data in Lua state
|
|
||||||
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
|
|
||||||
// Set session ID
|
|
||||||
state.PushString(sessionID)
|
|
||||||
state.SetGlobal("__session_id")
|
|
||||||
|
|
||||||
// Set session data
|
|
||||||
if data == nil {
|
|
||||||
data = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := state.PushTable(data); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
state.SetGlobal("__session_data")
|
|
||||||
|
|
||||||
// Reset modification flag
|
|
||||||
state.PushBoolean(false)
|
|
||||||
state.SetGlobal("__session_modified")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -14,9 +14,6 @@ __ready_modules = {}
|
||||||
__session_data = {}
|
__session_data = {}
|
||||||
__session_id = nil
|
__session_id = nil
|
||||||
__session_modified = false
|
__session_modified = false
|
||||||
__env_system = {
|
|
||||||
base_env = {}
|
|
||||||
}
|
|
||||||
|
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
-- CORE SANDBOX FUNCTIONALITY
|
-- CORE SANDBOX FUNCTIONALITY
|
||||||
|
@ -44,7 +41,7 @@ end
|
||||||
function __execute_script(fn, ctx)
|
function __execute_script(fn, ctx)
|
||||||
-- Clear previous responses
|
-- Clear previous responses
|
||||||
__http_responses[1] = nil
|
__http_responses[1] = nil
|
||||||
|
|
||||||
-- Reset session modification flag
|
-- Reset session modification flag
|
||||||
__session_modified = false
|
__session_modified = false
|
||||||
|
|
||||||
|
@ -63,75 +60,6 @@ function __execute_script(fn, ctx)
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
|
|
||||||
-- ======================================================================
|
|
||||||
-- MODULE LOADING SYSTEM
|
|
||||||
-- ======================================================================
|
|
||||||
|
|
||||||
-- Setup environment-aware require function
|
|
||||||
function __setup_require(env)
|
|
||||||
-- Create require function specific to this environment
|
|
||||||
env.require = function(modname)
|
|
||||||
-- Check if already loaded
|
|
||||||
if package.loaded[modname] then
|
|
||||||
return package.loaded[modname]
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Check preloaded modules
|
|
||||||
if __ready_modules[modname] then
|
|
||||||
local loader = package.preload[modname]
|
|
||||||
if loader then
|
|
||||||
-- Set environment for loader
|
|
||||||
setfenv(loader, env)
|
|
||||||
|
|
||||||
-- Execute and store result
|
|
||||||
local result = loader()
|
|
||||||
if result == nil then
|
|
||||||
result = true
|
|
||||||
end
|
|
||||||
|
|
||||||
package.loaded[modname] = result
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Direct file load as fallback
|
|
||||||
if __module_paths[modname] then
|
|
||||||
local path = __module_paths[modname]
|
|
||||||
local chunk, err = loadfile(path)
|
|
||||||
if chunk then
|
|
||||||
setfenv(chunk, env)
|
|
||||||
local result = chunk()
|
|
||||||
if result == nil then
|
|
||||||
result = true
|
|
||||||
end
|
|
||||||
package.loaded[modname] = result
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Full path search as last resort
|
|
||||||
local errors = {}
|
|
||||||
for path in package.path:gmatch("[^;]+") do
|
|
||||||
local file_path = path:gsub("?", modname:gsub("%.", "/"))
|
|
||||||
local chunk, err = loadfile(file_path)
|
|
||||||
if chunk then
|
|
||||||
setfenv(chunk, env)
|
|
||||||
local result = chunk()
|
|
||||||
if result == nil then
|
|
||||||
result = true
|
|
||||||
end
|
|
||||||
package.loaded[modname] = result
|
|
||||||
return result
|
|
||||||
end
|
|
||||||
table.insert(errors, "\tno file '" .. file_path .. "'")
|
|
||||||
end
|
|
||||||
|
|
||||||
error("module '" .. modname .. "' not found:\n" .. table.concat(errors, "\n"), 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
return env
|
|
||||||
end
|
|
||||||
|
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
-- HTTP MODULE
|
-- HTTP MODULE
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
|
@ -166,6 +94,18 @@ local http = {
|
||||||
http.set_header("Content-Type", content_type)
|
http.set_header("Content-Type", content_type)
|
||||||
end,
|
end,
|
||||||
|
|
||||||
|
-- Set metadata (arbitrary data to be returned with response)
|
||||||
|
set_metadata = function(key, value)
|
||||||
|
if type(key) ~= "string" then
|
||||||
|
error("http.set_metadata: key must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
local resp = __http_responses[1] or {}
|
||||||
|
resp.metadata = resp.metadata or {}
|
||||||
|
resp.metadata[key] = value
|
||||||
|
__http_responses[1] = resp
|
||||||
|
end,
|
||||||
|
|
||||||
-- HTTP client submodule
|
-- HTTP client submodule
|
||||||
client = {
|
client = {
|
||||||
-- Generic request function
|
-- Generic request function
|
||||||
|
@ -213,10 +153,7 @@ local http = {
|
||||||
-- Simple HEAD request
|
-- Simple HEAD request
|
||||||
head = function(url, options)
|
head = function(url, options)
|
||||||
options = options or {}
|
options = options or {}
|
||||||
local old_options = options
|
return http.client.request("HEAD", url, nil, options)
|
||||||
options = {headers = old_options.headers, timeout = old_options.timeout, query = old_options.query}
|
|
||||||
local response = http.client.request("HEAD", url, nil, options)
|
|
||||||
return response
|
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Simple OPTIONS request
|
-- Simple OPTIONS request
|
||||||
|
@ -265,13 +202,13 @@ local http = {
|
||||||
}
|
}
|
||||||
|
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
-- COOKIE MODULE
|
-- COOKIE MODULE
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
|
|
||||||
-- Cookie module implementation
|
-- Cookie module implementation
|
||||||
local cookie = {
|
local cookie = {
|
||||||
-- Set a cookie
|
-- Set a cookie
|
||||||
set = function(name, value, options, ...)
|
set = function(name, value, options)
|
||||||
if type(name) ~= "string" then
|
if type(name) ~= "string" then
|
||||||
error("cookie.set: name must be a string", 2)
|
error("cookie.set: name must be a string", 2)
|
||||||
end
|
end
|
||||||
|
@ -281,20 +218,8 @@ local cookie = {
|
||||||
resp.cookies = resp.cookies or {}
|
resp.cookies = resp.cookies or {}
|
||||||
__http_responses[1] = resp
|
__http_responses[1] = resp
|
||||||
|
|
||||||
-- Handle options as table or legacy params
|
-- Handle options as table
|
||||||
local opts = {}
|
local opts = options or {}
|
||||||
if type(options) == "table" then
|
|
||||||
opts = options
|
|
||||||
elseif options ~= nil then
|
|
||||||
-- Legacy support: options is actually 'expires'
|
|
||||||
opts.expires = options
|
|
||||||
-- Check for other legacy params (4th-7th args)
|
|
||||||
local args = {...}
|
|
||||||
if args[1] then opts.path = args[1] end
|
|
||||||
if args[2] then opts.domain = args[2] end
|
|
||||||
if args[3] then opts.secure = args[3] end
|
|
||||||
if args[4] ~= nil then opts.http_only = args[4] end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Create cookie table
|
-- Create cookie table
|
||||||
local cookie = {
|
local cookie = {
|
||||||
|
@ -314,10 +239,8 @@ local cookie = {
|
||||||
elseif opts.expires < 0 then
|
elseif opts.expires < 0 then
|
||||||
cookie.expires = 1
|
cookie.expires = 1
|
||||||
cookie.max_age = 0
|
cookie.max_age = 0
|
||||||
else
|
|
||||||
-- opts.expires == 0: Session cookie
|
|
||||||
-- Do nothing (omitting both expires and max-age creates a session cookie)
|
|
||||||
end
|
end
|
||||||
|
-- opts.expires == 0: Session cookie (omitting both expires and max-age)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -342,8 +265,13 @@ local cookie = {
|
||||||
local env = getfenv(2)
|
local env = getfenv(2)
|
||||||
|
|
||||||
-- Check if context exists and has cookies
|
-- Check if context exists and has cookies
|
||||||
if env.ctx and env.ctx.cookies and env.ctx.cookies[name] then
|
if env.ctx and env.ctx.cookies then
|
||||||
return tostring(env.ctx.cookies[name])
|
return env.ctx.cookies[name]
|
||||||
|
end
|
||||||
|
|
||||||
|
-- If context has request_cookies map
|
||||||
|
if env.ctx and env.ctx._request_cookies then
|
||||||
|
return env.ctx._request_cookies[name]
|
||||||
end
|
end
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -361,7 +289,7 @@ local cookie = {
|
||||||
}
|
}
|
||||||
|
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
-- SESSION MODULE
|
-- SESSION MODULE
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
|
|
||||||
-- Session module implementation
|
-- Session module implementation
|
||||||
|
@ -372,7 +300,7 @@ local session = {
|
||||||
error("session.get: key must be a string", 2)
|
error("session.get: key must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
if __session_data and __session_data[key] then
|
if __session_data and __session_data[key] ~= nil then
|
||||||
return __session_data[key]
|
return __session_data[key]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -469,7 +397,7 @@ local csrf = {
|
||||||
error("CSRF protection requires the session module", 2)
|
error("CSRF protection requires the session module", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local token = util.generate_token(length)
|
local token = __generate_token(length)
|
||||||
session.set(csrf.TOKEN_KEY, token)
|
session.set(csrf.TOKEN_KEY, token)
|
||||||
return token
|
return token
|
||||||
end,
|
end,
|
||||||
|
@ -495,48 +423,133 @@ local csrf = {
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Verify a given token against the session token
|
-- Verify a given token against the session token
|
||||||
verify = function(token, field_name)
|
verify = function(token, field_name)
|
||||||
field_name = field_name or csrf.DEFAULT_FIELD
|
field_name = field_name or csrf.DEFAULT_FIELD
|
||||||
|
|
||||||
local env = getfenv(2)
|
local env = getfenv(2)
|
||||||
|
|
||||||
local form = nil
|
local form = nil
|
||||||
if env.ctx and env.ctx.form then
|
if env.ctx and env.ctx._request_form then
|
||||||
form = env.ctx.form
|
form = env.ctx._request_form
|
||||||
else
|
elseif env.ctx and env.ctx.form then
|
||||||
return false
|
form = env.ctx.form
|
||||||
end
|
else
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
token = token or form[field_name]
|
token = token or form[field_name]
|
||||||
if not token then
|
if not token then
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
local session_token = session.get(csrf.TOKEN_KEY)
|
local session_token = session.get(csrf.TOKEN_KEY)
|
||||||
if not session_token then
|
if not session_token then
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Constant-time comparison to prevent timing attacks
|
-- Constant-time comparison to prevent timing attacks
|
||||||
-- This is safe since Lua strings are immutable
|
if #token ~= #session_token then
|
||||||
if #token ~= #session_token then
|
return false
|
||||||
return false
|
end
|
||||||
end
|
|
||||||
|
|
||||||
local result = true
|
local result = true
|
||||||
for i = 1, #token do
|
for i = 1, #token do
|
||||||
if token:sub(i, i) ~= session_token:sub(i, i) then
|
if token:sub(i, i) ~= session_token:sub(i, i) then
|
||||||
result = false
|
result = false
|
||||||
-- Don't break early - continue to prevent timing attacks
|
-- Don't break early - continue to prevent timing attacks
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return result
|
return result
|
||||||
end
|
end
|
||||||
}
|
}
|
||||||
|
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
-- REGISTER MODULES GLOBALLY
|
-- UTIL MODULE
|
||||||
|
-- ======================================================================
|
||||||
|
|
||||||
|
-- Utility module implementation
|
||||||
|
local util = {
|
||||||
|
-- Generate a token (wrapper around __generate_token)
|
||||||
|
generate_token = function(length)
|
||||||
|
return __generate_token(length or 32)
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Simple JSON stringify (for when you just need a quick string)
|
||||||
|
json_encode = function(value)
|
||||||
|
if type(value) == "table" then
|
||||||
|
local json = "{"
|
||||||
|
local sep = ""
|
||||||
|
for k, v in pairs(value) do
|
||||||
|
json = json .. sep
|
||||||
|
if type(k) == "number" then
|
||||||
|
-- Array-like
|
||||||
|
json = json .. util.json_encode(v)
|
||||||
|
else
|
||||||
|
-- Object-like
|
||||||
|
json = json .. '"' .. k .. '":' .. util.json_encode(v)
|
||||||
|
end
|
||||||
|
sep = ","
|
||||||
|
end
|
||||||
|
return json .. "}"
|
||||||
|
elseif type(value) == "string" then
|
||||||
|
return '"' .. value:gsub('"', '\\"'):gsub('\n', '\\n') .. '"'
|
||||||
|
elseif type(value) == "number" then
|
||||||
|
return tostring(value)
|
||||||
|
elseif type(value) == "boolean" then
|
||||||
|
return value and "true" or "false"
|
||||||
|
elseif value == nil then
|
||||||
|
return "null"
|
||||||
|
end
|
||||||
|
return '"' .. tostring(value) .. '"'
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Deep copy of tables
|
||||||
|
deep_copy = function(obj)
|
||||||
|
if type(obj) ~= 'table' then return obj end
|
||||||
|
local res = {}
|
||||||
|
for k, v in pairs(obj) do res[k] = util.deep_copy(v) end
|
||||||
|
return res
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Merge tables
|
||||||
|
merge_tables = function(t1, t2)
|
||||||
|
if type(t1) ~= 'table' or type(t2) ~= 'table' then
|
||||||
|
error("Both arguments must be tables", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
local result = util.deep_copy(t1)
|
||||||
|
for k, v in pairs(t2) do
|
||||||
|
if type(v) == 'table' and type(result[k]) == 'table' then
|
||||||
|
result[k] = util.merge_tables(result[k], v)
|
||||||
|
else
|
||||||
|
result[k] = v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return result
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- String utilities
|
||||||
|
string = {
|
||||||
|
-- Trim whitespace
|
||||||
|
trim = function(s)
|
||||||
|
return (s:gsub("^%s*(.-)%s*$", "%1"))
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Split string
|
||||||
|
split = function(s, delimiter)
|
||||||
|
delimiter = delimiter or ","
|
||||||
|
local result = {}
|
||||||
|
for match in (s..delimiter):gmatch("(.-)"..delimiter) do
|
||||||
|
table.insert(result, match)
|
||||||
|
end
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
-- ======================================================================
|
||||||
|
-- REGISTER MODULES GLOBALLY
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
|
|
||||||
-- Install modules in global scope
|
-- Install modules in global scope
|
||||||
|
@ -544,9 +557,4 @@ _G.http = http
|
||||||
_G.cookie = cookie
|
_G.cookie = cookie
|
||||||
_G.session = session
|
_G.session = session
|
||||||
_G.csrf = csrf
|
_G.csrf = csrf
|
||||||
|
_G.util = util
|
||||||
-- Register modules in sandbox base environment
|
|
||||||
__env_system.base_env.http = http
|
|
||||||
__env_system.base_env.cookie = cookie
|
|
||||||
__env_system.base_env.session = session
|
|
||||||
__env_system.base_env.csrf = csrf
|
|
|
@ -1,98 +0,0 @@
|
||||||
package sandbox
|
|
||||||
|
|
||||||
import (
|
|
||||||
_ "embed"
|
|
||||||
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:embed lua/sandbox.lua
|
|
||||||
var sandboxLua string
|
|
||||||
|
|
||||||
// InitializeSandbox loads the embedded Lua sandbox code into a Lua state
|
|
||||||
func InitializeSandbox(state *luajit.State) error {
|
|
||||||
// Compile once, use many times
|
|
||||||
bytecodeOnce.Do(precompileSandbox)
|
|
||||||
|
|
||||||
if sandboxBytecode != nil {
|
|
||||||
logger.Debug("Loading sandbox.lua from precompiled bytecode")
|
|
||||||
return state.LoadAndRunBytecode(sandboxBytecode, "sandbox.lua")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback if compilation failed
|
|
||||||
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)")
|
|
||||||
return state.DoString(sandboxLua)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModuleInitializers stores initializer functions for core modules
|
|
||||||
type ModuleInitializers struct {
|
|
||||||
HTTP func(*luajit.State) error
|
|
||||||
Util func(*luajit.State) error
|
|
||||||
Session func(*luajit.State) error
|
|
||||||
Cookie func(*luajit.State) error
|
|
||||||
CSRF func(*luajit.State) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultInitializers returns the default set of initializers
|
|
||||||
func DefaultInitializers() *ModuleInitializers {
|
|
||||||
return &ModuleInitializers{
|
|
||||||
HTTP: func(state *luajit.State) error {
|
|
||||||
// Register the native Go function first
|
|
||||||
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
|
||||||
logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Util: func(state *luajit.State) error {
|
|
||||||
// Register util functions
|
|
||||||
return RegisterModule(state, "util", UtilModuleFunctions())
|
|
||||||
},
|
|
||||||
Session: func(state *luajit.State) error {
|
|
||||||
// Session doesn't need special initialization
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
Cookie: func(state *luajit.State) error {
|
|
||||||
// Cookie doesn't need special initialization
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
CSRF: func(state *luajit.State) error {
|
|
||||||
// CSRF doesn't need special initialization
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitializeAll initializes all modules in the Lua state
|
|
||||||
func InitializeAll(state *luajit.State, initializers *ModuleInitializers) error {
|
|
||||||
// Set up dependencies first
|
|
||||||
if err := initializers.Util(state); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := initializers.HTTP(state); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the embedded sandbox code
|
|
||||||
if err := InitializeSandbox(state); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize the rest of the modules
|
|
||||||
if err := initializers.Session(state); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := initializers.Cookie(state); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := initializers.CSRF(state); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,590 +0,0 @@
|
||||||
package sandbox
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
|
||||||
"github.com/valyala/bytebufferpool"
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SessionHandler interface for session management
|
|
||||||
type SessionHandler interface {
|
|
||||||
LoadSession(ctx *fasthttp.RequestCtx) (string, map[string]any)
|
|
||||||
SaveSession(ctx *fasthttp.RequestCtx, sessionID string, data map[string]any) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPResponse represents an HTTP response from Lua
|
|
||||||
type HTTPResponse struct {
|
|
||||||
Status int `json:"status"`
|
|
||||||
Headers map[string]string `json:"headers"`
|
|
||||||
Body any `json:"body"`
|
|
||||||
Cookies []*fasthttp.Cookie `json:"-"`
|
|
||||||
SessionModified bool `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Response pool to reduce allocations
|
|
||||||
var responsePool = sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return &HTTPResponse{
|
|
||||||
Status: 200,
|
|
||||||
Headers: make(map[string]string, 8),
|
|
||||||
Cookies: make([]*fasthttp.Cookie, 0, 4),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default HTTP client with sensible timeout
|
|
||||||
var defaultFastClient fasthttp.Client = fasthttp.Client{
|
|
||||||
MaxConnsPerHost: 1024,
|
|
||||||
MaxIdleConnDuration: time.Minute,
|
|
||||||
ReadTimeout: 30 * time.Second,
|
|
||||||
WriteTimeout: 30 * time.Second,
|
|
||||||
DisableHeaderNamesNormalizing: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPClientConfig contains client settings
|
|
||||||
type HTTPClientConfig struct {
|
|
||||||
MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit)
|
|
||||||
DefaultTimeout time.Duration // Default request timeout
|
|
||||||
MaxResponseSize int64 // Maximum response size in bytes (0 = no limit)
|
|
||||||
AllowRemote bool // Whether to allow remote connections
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultHTTPClientConfig provides sensible defaults
|
|
||||||
var DefaultHTTPClientConfig = HTTPClientConfig{
|
|
||||||
MaxTimeout: 60 * time.Second,
|
|
||||||
DefaultTimeout: 30 * time.Second,
|
|
||||||
MaxResponseSize: 10 * 1024 * 1024, // 10MB
|
|
||||||
AllowRemote: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHTTPResponse creates a default HTTP response from pool
|
|
||||||
func NewHTTPResponse() *HTTPResponse {
|
|
||||||
return responsePool.Get().(*HTTPResponse)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReleaseResponse returns the response to the pool
|
|
||||||
func ReleaseResponse(resp *HTTPResponse) {
|
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear all values to prevent data leakage
|
|
||||||
resp.Status = 200 // Reset to default
|
|
||||||
|
|
||||||
// Clear headers
|
|
||||||
for k := range resp.Headers {
|
|
||||||
delete(resp.Headers, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear cookies
|
|
||||||
resp.Cookies = resp.Cookies[:0] // Keep capacity but set length to 0
|
|
||||||
|
|
||||||
// Reset session flag
|
|
||||||
resp.SessionModified = false
|
|
||||||
|
|
||||||
// Clear body
|
|
||||||
resp.Body = nil
|
|
||||||
|
|
||||||
responsePool.Put(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPModuleInitFunc returns an initializer function for the HTTP module
|
|
||||||
func HTTPModuleInitFunc() func(*luajit.State) error {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
// Register the native Go function first
|
|
||||||
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
|
||||||
logger.Error("[HTTP Module] Failed to register __http_request function: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up default HTTP client configuration
|
|
||||||
setupHTTPClientConfig(state)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupHTTPClientConfig configures HTTP client in Lua
|
|
||||||
func setupHTTPClientConfig(state *luajit.State) {
|
|
||||||
state.NewTable()
|
|
||||||
|
|
||||||
state.PushNumber(float64(DefaultHTTPClientConfig.MaxTimeout / time.Second))
|
|
||||||
state.SetField(-2, "max_timeout")
|
|
||||||
|
|
||||||
state.PushNumber(float64(DefaultHTTPClientConfig.DefaultTimeout / time.Second))
|
|
||||||
state.SetField(-2, "default_timeout")
|
|
||||||
|
|
||||||
state.PushNumber(float64(DefaultHTTPClientConfig.MaxResponseSize))
|
|
||||||
state.SetField(-2, "max_response_size")
|
|
||||||
|
|
||||||
state.PushBoolean(DefaultHTTPClientConfig.AllowRemote)
|
|
||||||
state.SetField(-2, "allow_remote")
|
|
||||||
|
|
||||||
state.SetGlobal("__http_client_config")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetHTTPResponse extracts the HTTP response from Lua state
|
|
||||||
func GetHTTPResponse(state *luajit.State) (*HTTPResponse, bool) {
|
|
||||||
response := NewHTTPResponse()
|
|
||||||
|
|
||||||
// Get response table
|
|
||||||
state.GetGlobal("__http_responses")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
ReleaseResponse(response)
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for response at thread index
|
|
||||||
state.PushNumber(1)
|
|
||||||
state.GetTable(-2)
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(2)
|
|
||||||
ReleaseResponse(response)
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get status
|
|
||||||
state.GetField(-1, "status")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
response.Status = int(state.ToNumber(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get headers
|
|
||||||
state.GetField(-1, "headers")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
// Iterate through headers table
|
|
||||||
state.PushNil() // Start iteration
|
|
||||||
for state.Next(-2) {
|
|
||||||
// Stack has key at -2 and value at -1
|
|
||||||
if state.IsString(-2) && state.IsString(-1) {
|
|
||||||
key := state.ToString(-2)
|
|
||||||
value := state.ToString(-1)
|
|
||||||
response.Headers[key] = value
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop value, leave key for next iteration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get cookies
|
|
||||||
state.GetField(-1, "cookies")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
// Iterate through cookies array
|
|
||||||
length := state.GetTableLength(-1)
|
|
||||||
for i := 1; i <= length; i++ {
|
|
||||||
state.PushNumber(float64(i))
|
|
||||||
state.GetTable(-2)
|
|
||||||
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
cookie := extractCookie(state)
|
|
||||||
if cookie != nil {
|
|
||||||
response.Cookies = append(response.Cookies, cookie)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Check if session was modified
|
|
||||||
state.GetGlobal("__session_modified")
|
|
||||||
if state.IsBoolean(-1) && state.ToBoolean(-1) {
|
|
||||||
response.SessionModified = true
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
state.Pop(2) // Pop response table and __http_responses
|
|
||||||
|
|
||||||
return response, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyHTTPResponse applies an HTTP response to a fasthttp.RequestCtx
|
|
||||||
func ApplyHTTPResponse(httpResp *HTTPResponse, ctx *fasthttp.RequestCtx) {
|
|
||||||
// Set status code
|
|
||||||
ctx.SetStatusCode(httpResp.Status)
|
|
||||||
|
|
||||||
// Set headers
|
|
||||||
for name, value := range httpResp.Headers {
|
|
||||||
ctx.Response.Header.Set(name, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set cookies
|
|
||||||
for _, cookie := range httpResp.Cookies {
|
|
||||||
ctx.Response.Header.SetCookie(cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process the body based on its type
|
|
||||||
if httpResp.Body == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set body based on type
|
|
||||||
switch body := httpResp.Body.(type) {
|
|
||||||
case string:
|
|
||||||
ctx.SetBodyString(body)
|
|
||||||
case []byte:
|
|
||||||
ctx.SetBody(body)
|
|
||||||
case map[string]any, []any, []float64, []string, []int:
|
|
||||||
// Marshal JSON using a buffer from the pool
|
|
||||||
buf := bytebufferpool.Get()
|
|
||||||
defer bytebufferpool.Put(buf)
|
|
||||||
|
|
||||||
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
|
||||||
// Set content type if not already set
|
|
||||||
if len(ctx.Response.Header.ContentType()) == 0 {
|
|
||||||
ctx.Response.Header.SetContentType("application/json")
|
|
||||||
}
|
|
||||||
ctx.SetBody(buf.Bytes())
|
|
||||||
} else {
|
|
||||||
// Fallback
|
|
||||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// Default to string representation
|
|
||||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractCookie grabs cookies from the Lua state
|
|
||||||
func extractCookie(state *luajit.State) *fasthttp.Cookie {
|
|
||||||
cookie := fasthttp.AcquireCookie()
|
|
||||||
|
|
||||||
// Get name
|
|
||||||
state.GetField(-1, "name")
|
|
||||||
if !state.IsString(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
fasthttp.ReleaseCookie(cookie)
|
|
||||||
return nil // Name is required
|
|
||||||
}
|
|
||||||
cookie.SetKey(state.ToString(-1))
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get value
|
|
||||||
state.GetField(-1, "value")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetValue(state.ToString(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get path
|
|
||||||
state.GetField(-1, "path")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetPath(state.ToString(-1))
|
|
||||||
} else {
|
|
||||||
cookie.SetPath("/") // Default path
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get domain
|
|
||||||
state.GetField(-1, "domain")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetDomain(state.ToString(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get expires
|
|
||||||
state.GetField(-1, "expires")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
expiry := int64(state.ToNumber(-1))
|
|
||||||
cookie.SetExpire(time.Unix(expiry, 0))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get max age
|
|
||||||
state.GetField(-1, "max_age")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get secure
|
|
||||||
state.GetField(-1, "secure")
|
|
||||||
if state.IsBoolean(-1) {
|
|
||||||
cookie.SetSecure(state.ToBoolean(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get http only
|
|
||||||
state.GetField(-1, "http_only")
|
|
||||||
if state.IsBoolean(-1) {
|
|
||||||
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
return cookie
|
|
||||||
}
|
|
||||||
|
|
||||||
// httpRequest makes an HTTP request and returns the result to Lua
|
|
||||||
func httpRequest(state *luajit.State) int {
|
|
||||||
// Get method (required)
|
|
||||||
if !state.IsString(1) {
|
|
||||||
state.PushString("http.client.request: method must be a string")
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
method := strings.ToUpper(state.ToString(1))
|
|
||||||
|
|
||||||
// Get URL (required)
|
|
||||||
if !state.IsString(2) {
|
|
||||||
state.PushString("http.client.request: url must be a string")
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
urlStr := state.ToString(2)
|
|
||||||
|
|
||||||
// Parse URL to check if it's valid and if it's allowed
|
|
||||||
parsedURL, err := url.Parse(urlStr)
|
|
||||||
if err != nil {
|
|
||||||
state.PushString("Invalid URL: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get client configuration
|
|
||||||
var config HTTPClientConfig = DefaultHTTPClientConfig
|
|
||||||
state.GetGlobal("__http_client_config")
|
|
||||||
if !state.IsNil(-1) && state.IsTable(-1) {
|
|
||||||
// Extract max timeout
|
|
||||||
state.GetField(-1, "max_timeout")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
config.MaxTimeout = time.Duration(state.ToNumber(-1)) * time.Second
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Extract default timeout
|
|
||||||
state.GetField(-1, "default_timeout")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
config.DefaultTimeout = time.Duration(state.ToNumber(-1)) * time.Second
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Extract max response size
|
|
||||||
state.GetField(-1, "max_response_size")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
config.MaxResponseSize = int64(state.ToNumber(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Extract allow remote
|
|
||||||
state.GetField(-1, "allow_remote")
|
|
||||||
if state.IsBoolean(-1) {
|
|
||||||
config.AllowRemote = state.ToBoolean(-1)
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Check if remote connections are allowed
|
|
||||||
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
|
|
||||||
state.PushString("Remote connections are not allowed")
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use bytebufferpool for request and response
|
|
||||||
req := fasthttp.AcquireRequest()
|
|
||||||
resp := fasthttp.AcquireResponse()
|
|
||||||
defer fasthttp.ReleaseRequest(req)
|
|
||||||
defer fasthttp.ReleaseResponse(resp)
|
|
||||||
|
|
||||||
// Set up request
|
|
||||||
req.Header.SetMethod(method)
|
|
||||||
req.SetRequestURI(urlStr)
|
|
||||||
req.Header.Set("User-Agent", "Moonshark/1.0")
|
|
||||||
|
|
||||||
// Get body (optional)
|
|
||||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
|
||||||
if state.IsString(3) {
|
|
||||||
// String body
|
|
||||||
req.SetBodyString(state.ToString(3))
|
|
||||||
} else if state.IsTable(3) {
|
|
||||||
// Table body - convert to JSON
|
|
||||||
luaTable, err := state.ToTable(3)
|
|
||||||
if err != nil {
|
|
||||||
state.PushString("Failed to parse body table: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use bytebufferpool for JSON serialization
|
|
||||||
buf := bytebufferpool.Get()
|
|
||||||
defer bytebufferpool.Put(buf)
|
|
||||||
|
|
||||||
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
|
|
||||||
state.PushString("Failed to convert body to JSON: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
req.SetBody(buf.Bytes())
|
|
||||||
} else {
|
|
||||||
state.PushString("Body must be a string or table")
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process options (headers, timeout, etc.)
|
|
||||||
timeout := config.DefaultTimeout
|
|
||||||
if state.GetTop() >= 4 && !state.IsNil(4) {
|
|
||||||
if !state.IsTable(4) {
|
|
||||||
state.PushString("Options must be a table")
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process headers
|
|
||||||
state.GetField(4, "headers")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
// Iterate through headers
|
|
||||||
state.PushNil() // Start iteration
|
|
||||||
for state.Next(-2) {
|
|
||||||
// Stack now has key at -2 and value at -1
|
|
||||||
if state.IsString(-2) && state.IsString(-1) {
|
|
||||||
headerName := state.ToString(-2)
|
|
||||||
headerValue := state.ToString(-1)
|
|
||||||
req.Header.Set(headerName, headerValue)
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop value, leave key for next iteration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop headers table
|
|
||||||
|
|
||||||
// Get timeout
|
|
||||||
state.GetField(4, "timeout")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second
|
|
||||||
|
|
||||||
// Apply max timeout if configured
|
|
||||||
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
|
|
||||||
timeout = config.MaxTimeout
|
|
||||||
} else {
|
|
||||||
timeout = requestTimeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop timeout
|
|
||||||
|
|
||||||
// Set content type for POST/PUT if body is present and content-type not manually set
|
|
||||||
if (method == "POST" || method == "PUT") && req.Body() != nil && req.Header.Peek("Content-Type") == nil {
|
|
||||||
// Check if options specify content type
|
|
||||||
state.GetField(4, "content_type")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
req.Header.Set("Content-Type", state.ToString(-1))
|
|
||||||
} else {
|
|
||||||
// Default to JSON if body is a table, otherwise plain text
|
|
||||||
if state.IsTable(3) {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Content-Type", "text/plain")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop content_type
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process query parameters
|
|
||||||
state.GetField(4, "query")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
// Create URL args
|
|
||||||
args := req.URI().QueryArgs()
|
|
||||||
|
|
||||||
// Iterate through query params
|
|
||||||
state.PushNil() // Start iteration
|
|
||||||
for state.Next(-2) {
|
|
||||||
// Stack now has key at -2 and value at -1
|
|
||||||
if state.IsString(-2) {
|
|
||||||
paramName := state.ToString(-2)
|
|
||||||
|
|
||||||
// Handle different value types
|
|
||||||
if state.IsString(-1) {
|
|
||||||
args.Add(paramName, state.ToString(-1))
|
|
||||||
} else if state.IsNumber(-1) {
|
|
||||||
args.Add(paramName, strings.TrimRight(strings.TrimRight(
|
|
||||||
state.ToString(-1), "0"), "."))
|
|
||||||
} else if state.IsBoolean(-1) {
|
|
||||||
if state.ToBoolean(-1) {
|
|
||||||
args.Add(paramName, "true")
|
|
||||||
} else {
|
|
||||||
args.Add(paramName, "false")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop value, leave key for next iteration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop query table
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create context with timeout
|
|
||||||
_, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Execute request
|
|
||||||
err = defaultFastClient.DoTimeout(req, resp, timeout)
|
|
||||||
if err != nil {
|
|
||||||
errStr := "Request failed: " + err.Error()
|
|
||||||
if errors.Is(err, fasthttp.ErrTimeout) {
|
|
||||||
errStr = "Request timed out after " + timeout.String()
|
|
||||||
}
|
|
||||||
state.PushString(errStr)
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create response table
|
|
||||||
state.NewTable()
|
|
||||||
|
|
||||||
// Set status code
|
|
||||||
state.PushNumber(float64(resp.StatusCode()))
|
|
||||||
state.SetField(-2, "status")
|
|
||||||
|
|
||||||
// Set status text
|
|
||||||
statusText := fasthttp.StatusMessage(resp.StatusCode())
|
|
||||||
state.PushString(statusText)
|
|
||||||
state.SetField(-2, "status_text")
|
|
||||||
|
|
||||||
// Set body
|
|
||||||
var respBody []byte
|
|
||||||
|
|
||||||
// Apply size limits to response
|
|
||||||
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
|
|
||||||
// Make a limited copy
|
|
||||||
respBody = make([]byte, config.MaxResponseSize)
|
|
||||||
copy(respBody, resp.Body())
|
|
||||||
} else {
|
|
||||||
respBody = resp.Body()
|
|
||||||
}
|
|
||||||
|
|
||||||
state.PushString(string(respBody))
|
|
||||||
state.SetField(-2, "body")
|
|
||||||
|
|
||||||
// Parse body as JSON if content type is application/json
|
|
||||||
contentType := string(resp.Header.ContentType())
|
|
||||||
if strings.Contains(contentType, "application/json") {
|
|
||||||
var jsonData any
|
|
||||||
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
|
||||||
if err := state.PushValue(jsonData); err == nil {
|
|
||||||
state.SetField(-2, "json")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set headers
|
|
||||||
state.NewTable()
|
|
||||||
resp.Header.VisitAll(func(key, value []byte) {
|
|
||||||
state.PushString(string(value))
|
|
||||||
state.SetField(-2, string(key))
|
|
||||||
})
|
|
||||||
state.SetField(-2, "headers")
|
|
||||||
|
|
||||||
// Create ok field (true if status code is 2xx)
|
|
||||||
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
|
|
||||||
state.SetField(-2, "ok")
|
|
||||||
|
|
||||||
return 1
|
|
||||||
}
|
|
|
@ -1,84 +0,0 @@
|
||||||
package sandbox
|
|
||||||
|
|
||||||
import (
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ModuleFunc returns a map of module functions
|
|
||||||
type ModuleFunc func() map[string]luajit.GoFunction
|
|
||||||
|
|
||||||
// StateInitFunc initializes a module in a Lua state
|
|
||||||
type StateInitFunc func(*luajit.State) error
|
|
||||||
|
|
||||||
// RegisterModule registers a map of functions as a Lua module
|
|
||||||
func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
|
|
||||||
// Create a new table for the module
|
|
||||||
state.NewTable()
|
|
||||||
|
|
||||||
// Add each function to the module table
|
|
||||||
for fname, f := range funcs {
|
|
||||||
state.PushString(fname)
|
|
||||||
if err := state.PushGoFunction(f); err != nil {
|
|
||||||
state.Pop(1) // Pop table
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
state.SetTable(-3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register the module globally
|
|
||||||
state.SetGlobal(name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModuleInitFunc creates a state initializer that registers multiple modules
|
|
||||||
func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
for name, moduleFunc := range modules {
|
|
||||||
if err := RegisterModule(state, name, moduleFunc()); err != nil {
|
|
||||||
logger.Error("Failed to register module %s: %v", name, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CombineInitFuncs combines multiple state initializer functions into one
|
|
||||||
func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
for _, f := range funcs {
|
|
||||||
if f != nil {
|
|
||||||
if err := f(state); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterLuaCode registers a Lua code snippet in a state
|
|
||||||
func RegisterLuaCode(state *luajit.State, code string) error {
|
|
||||||
return state.DoString(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterLuaCodeInitFunc returns a StateInitFunc that registers Lua code
|
|
||||||
func RegisterLuaCodeInitFunc(code string) StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
return RegisterLuaCode(state, code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterLuaModuleInitFunc returns a StateInitFunc that registers a Lua module
|
|
||||||
func RegisterLuaModuleInitFunc(name string, code string) StateInitFunc {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
// Create name = {} global
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal(name)
|
|
||||||
|
|
||||||
// Then run the module code which will populate it
|
|
||||||
return state.DoString(code)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,371 +0,0 @@
|
||||||
package sandbox
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
|
||||||
"github.com/valyala/bytebufferpool"
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
luaCtx "Moonshark/core/runner/context"
|
|
||||||
"Moonshark/core/sessions"
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Global bytecode cache to improve performance
|
|
||||||
var (
|
|
||||||
sandboxBytecode []byte
|
|
||||||
bytecodeOnce sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
// precompileSandbox compiles the sandbox.lua code to bytecode once
|
|
||||||
func precompileSandbox() {
|
|
||||||
tempState := luajit.New()
|
|
||||||
if tempState == nil {
|
|
||||||
logger.Error("Failed to create temporary Lua state for bytecode compilation")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer tempState.Close()
|
|
||||||
defer tempState.Cleanup()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
sandboxBytecode, err = tempState.CompileBytecode(sandboxLua, "sandbox.lua")
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to precompile sandbox.lua: %v", err)
|
|
||||||
} else {
|
|
||||||
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(sandboxBytecode))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sandbox provides a secure execution environment for Lua scripts
|
|
||||||
type Sandbox struct {
|
|
||||||
modules map[string]any // Custom modules for environment
|
|
||||||
debug bool // Enable debug output
|
|
||||||
mu sync.RWMutex // Protects modules
|
|
||||||
initializers *ModuleInitializers // Module initializers
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSandbox creates a new sandbox environment
|
|
||||||
func NewSandbox() *Sandbox {
|
|
||||||
return &Sandbox{
|
|
||||||
modules: make(map[string]any, 8), // Pre-allocate with reasonable capacity
|
|
||||||
debug: false,
|
|
||||||
initializers: DefaultInitializers(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnableDebug turns on debug logging
|
|
||||||
func (s *Sandbox) EnableDebug() {
|
|
||||||
s.debug = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// debugLog logs a message if debug mode is enabled
|
|
||||||
func (s *Sandbox) debugLog(format string, args ...interface{}) {
|
|
||||||
if s.debug {
|
|
||||||
logger.Debug("Sandbox "+format, args...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// debugLogCont logs a continuation message if debug mode is enabled
|
|
||||||
func (s *Sandbox) debugLogCont(format string, args ...interface{}) {
|
|
||||||
if s.debug {
|
|
||||||
logger.DebugCont(format, args...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddModule adds a module to the sandbox environment
|
|
||||||
func (s *Sandbox) AddModule(name string, module any) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
s.modules[name] = module
|
|
||||||
s.debugLog("Added module: %s", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup initializes the sandbox in a Lua state
|
|
||||||
func (s *Sandbox) Setup(state *luajit.State, stateIndex int) error {
|
|
||||||
verbose := stateIndex == 0
|
|
||||||
|
|
||||||
if verbose {
|
|
||||||
s.debugLog("Setting up sandbox...")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize modules with the embedded sandbox code
|
|
||||||
if err := InitializeAll(state, s.initializers); err != nil {
|
|
||||||
if verbose {
|
|
||||||
s.debugLog("Failed to initialize sandbox: %v", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register custom modules in the global environment
|
|
||||||
s.mu.RLock()
|
|
||||||
for name, module := range s.modules {
|
|
||||||
if verbose {
|
|
||||||
s.debugLog("Registering module: %s", name)
|
|
||||||
}
|
|
||||||
if err := state.PushValue(module); err != nil {
|
|
||||||
s.mu.RUnlock()
|
|
||||||
if verbose {
|
|
||||||
s.debugLog("Failed to register module %s: %v", name, err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
state.SetGlobal(name)
|
|
||||||
}
|
|
||||||
s.mu.RUnlock()
|
|
||||||
|
|
||||||
if verbose {
|
|
||||||
s.debugLogCont("Sandbox setup complete")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute runs bytecode in the sandbox
|
|
||||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx map[string]any) (any, error) {
|
|
||||||
// Create a temporary context if we only have a map
|
|
||||||
if ctx != nil {
|
|
||||||
tempCtx := &luaCtx.Context{
|
|
||||||
Values: ctx,
|
|
||||||
}
|
|
||||||
return s.OptimizedExecute(state, bytecode, tempCtx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Just pass nil through if we have no context
|
|
||||||
return s.OptimizedExecute(state, bytecode, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OptimizedExecute runs bytecode with a fasthttp context if available
|
|
||||||
func (s *Sandbox) OptimizedExecute(state *luajit.State, bytecode []byte, ctx *luaCtx.Context) (any, error) {
|
|
||||||
// Use a buffer from the pool for any string operations
|
|
||||||
buf := bytebufferpool.Get()
|
|
||||||
defer bytebufferpool.Put(buf)
|
|
||||||
|
|
||||||
// Load bytecode
|
|
||||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
|
||||||
s.debugLog("Failed to load bytecode: %v", err)
|
|
||||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare context values
|
|
||||||
var ctxValues map[string]any
|
|
||||||
if ctx != nil {
|
|
||||||
ctxValues = ctx.Values
|
|
||||||
} else {
|
|
||||||
ctxValues = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize session tracking in Lua
|
|
||||||
if err := state.DoString("__session_data = {}; __session_modified = false"); err != nil {
|
|
||||||
s.debugLog("Failed to initialize session data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load session data if available
|
|
||||||
if ctx != nil && ctx.Session != nil {
|
|
||||||
// Set session ID in Lua
|
|
||||||
sessionIDCode := fmt.Sprintf("__session_id = %q", ctx.Session.ID)
|
|
||||||
if err := state.DoString(sessionIDCode); err != nil {
|
|
||||||
s.debugLog("Failed to set session ID: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get session data and populate Lua table
|
|
||||||
state.GetGlobal("__session_data")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
sessionData := ctx.Session.GetAll()
|
|
||||||
for k, v := range sessionData {
|
|
||||||
state.PushString(k)
|
|
||||||
if err := state.PushValue(v); err != nil {
|
|
||||||
s.debugLog("Failed to push session value %s: %v", k, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
state.SetTable(-3)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop __session_data
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare context table
|
|
||||||
if ctxValues != nil {
|
|
||||||
state.CreateTable(0, len(ctxValues))
|
|
||||||
for k, v := range ctxValues {
|
|
||||||
state.PushString(k)
|
|
||||||
if err := state.PushValue(v); err != nil {
|
|
||||||
state.Pop(2) // Pop key and table
|
|
||||||
s.debugLog("Failed to push context value %s: %v", k, err)
|
|
||||||
return nil, fmt.Errorf("failed to prepare context: %w", err)
|
|
||||||
}
|
|
||||||
state.SetTable(-3)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
state.PushNil() // No context
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get execution function
|
|
||||||
state.GetGlobal("__execute_script")
|
|
||||||
if !state.IsFunction(-1) {
|
|
||||||
state.Pop(2) // Pop context and non-function
|
|
||||||
s.debugLog("__execute_script is not a function")
|
|
||||||
return nil, fmt.Errorf("sandbox execution function not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stack setup for call: __execute_script, bytecode function, context
|
|
||||||
state.PushCopy(-3) // bytecode function (copy from -3)
|
|
||||||
state.PushCopy(-3) // context (copy from -3)
|
|
||||||
|
|
||||||
// Clean up duplicate references
|
|
||||||
state.Remove(-5) // Remove original bytecode function
|
|
||||||
state.Remove(-4) // Remove original context
|
|
||||||
|
|
||||||
// Call with 2 args (function, context), 1 result
|
|
||||||
if err := state.Call(2, 1); err != nil {
|
|
||||||
s.debugLog("Execution failed: %v", err)
|
|
||||||
return nil, fmt.Errorf("script execution failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get result
|
|
||||||
result, err := state.ToValue(-1)
|
|
||||||
state.Pop(1) // Pop result
|
|
||||||
|
|
||||||
// Extract session data if it was modified
|
|
||||||
if ctx != nil && ctx.Session != nil {
|
|
||||||
// Check if session was modified
|
|
||||||
state.GetGlobal("__session_modified")
|
|
||||||
if state.IsBoolean(-1) && state.ToBoolean(-1) {
|
|
||||||
ctx.SessionModified = true
|
|
||||||
|
|
||||||
// Extract session data
|
|
||||||
state.GetGlobal("__session_data")
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
// Clear existing data and extract new data from Lua
|
|
||||||
sessionData := make(map[string]any)
|
|
||||||
|
|
||||||
// Extract new session data
|
|
||||||
state.PushNil() // Start iteration
|
|
||||||
for state.Next(-2) {
|
|
||||||
// Stack now has key at -2 and value at -1
|
|
||||||
if state.IsString(-2) {
|
|
||||||
key := state.ToString(-2)
|
|
||||||
value, err := state.ToValue(-1)
|
|
||||||
if err == nil {
|
|
||||||
sessionData[key] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop value, leave key for next iteration
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update session with the new data
|
|
||||||
for k, v := range sessionData {
|
|
||||||
if err := ctx.Session.Set(k, v); err != nil {
|
|
||||||
s.debugLog("Failed to set session value %s: %v", k, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop __session_data
|
|
||||||
}
|
|
||||||
state.Pop(1) // Pop __session_modified
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for HTTP response
|
|
||||||
httpResponse, hasResponse := GetHTTPResponse(state)
|
|
||||||
if hasResponse {
|
|
||||||
// Add the script result as the response body
|
|
||||||
httpResponse.Body = result
|
|
||||||
|
|
||||||
// Mark session as modified if needed
|
|
||||||
if ctx != nil && ctx.SessionModified {
|
|
||||||
httpResponse.SessionModified = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have a fasthttp context, apply the response directly
|
|
||||||
if ctx != nil && ctx.RequestCtx != nil {
|
|
||||||
// If session was modified, save it
|
|
||||||
if ctx.SessionModified && ctx.Session != nil {
|
|
||||||
// Save session and set cookie if needed
|
|
||||||
sessions.GlobalSessionManager.SaveSession(ctx.Session)
|
|
||||||
|
|
||||||
// Add session cookie to the response
|
|
||||||
cookieOpts := sessions.GlobalSessionManager.CookieOptions()
|
|
||||||
cookie := fasthttp.AcquireCookie()
|
|
||||||
cookie.SetKey(cookieOpts["name"].(string))
|
|
||||||
cookie.SetValue(ctx.Session.ID)
|
|
||||||
cookie.SetPath(cookieOpts["path"].(string))
|
|
||||||
|
|
||||||
if domain, ok := cookieOpts["domain"].(string); ok && domain != "" {
|
|
||||||
cookie.SetDomain(domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxAge, ok := cookieOpts["max_age"].(int); ok && maxAge > 0 {
|
|
||||||
cookie.SetMaxAge(maxAge)
|
|
||||||
}
|
|
||||||
|
|
||||||
cookie.SetSecure(cookieOpts["secure"].(bool))
|
|
||||||
cookie.SetHTTPOnly(cookieOpts["http_only"].(bool))
|
|
||||||
|
|
||||||
// Add to response cookies
|
|
||||||
httpResponse.Cookies = append(httpResponse.Cookies, cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
ApplyHTTPResponse(httpResponse, ctx.RequestCtx)
|
|
||||||
ReleaseResponse(httpResponse)
|
|
||||||
return nil, nil // No need to return response object
|
|
||||||
}
|
|
||||||
|
|
||||||
return httpResponse, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have a fasthttp context and the result needs to be written directly
|
|
||||||
if ctx != nil && ctx.RequestCtx != nil && (result != nil) {
|
|
||||||
// For direct HTTP responses
|
|
||||||
switch r := result.(type) {
|
|
||||||
case string:
|
|
||||||
ctx.RequestCtx.SetBodyString(r)
|
|
||||||
case []byte:
|
|
||||||
ctx.RequestCtx.SetBody(r)
|
|
||||||
case map[string]any, []any:
|
|
||||||
// JSON response
|
|
||||||
ctx.RequestCtx.Response.Header.SetContentType("application/json")
|
|
||||||
if err := json.NewEncoder(buf).Encode(r); err == nil {
|
|
||||||
ctx.RequestCtx.SetBody(buf.Bytes())
|
|
||||||
} else {
|
|
||||||
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// Default string conversion
|
|
||||||
ctx.RequestCtx.SetBodyString(fmt.Sprintf("%v", r))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle session if modified
|
|
||||||
if ctx.SessionModified && ctx.Session != nil {
|
|
||||||
// Save session
|
|
||||||
sessions.GlobalSessionManager.SaveSession(ctx.Session)
|
|
||||||
|
|
||||||
// Add session cookie
|
|
||||||
cookieOpts := sessions.GlobalSessionManager.CookieOptions()
|
|
||||||
cookie := fasthttp.AcquireCookie()
|
|
||||||
cookie.SetKey(cookieOpts["name"].(string))
|
|
||||||
cookie.SetValue(ctx.Session.ID)
|
|
||||||
cookie.SetPath(cookieOpts["path"].(string))
|
|
||||||
|
|
||||||
if domain, ok := cookieOpts["domain"].(string); ok && domain != "" {
|
|
||||||
cookie.SetDomain(domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxAge, ok := cookieOpts["max_age"].(int); ok && maxAge > 0 {
|
|
||||||
cookie.SetMaxAge(maxAge)
|
|
||||||
}
|
|
||||||
|
|
||||||
cookie.SetSecure(cookieOpts["secure"].(bool))
|
|
||||||
cookie.SetHTTPOnly(cookieOpts["http_only"].(bool))
|
|
||||||
|
|
||||||
// Add to response
|
|
||||||
ctx.RequestCtx.Response.Header.SetCookie(cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
|
|
@ -1,58 +0,0 @@
|
||||||
package sandbox
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
|
|
||||||
"Moonshark/core/utils/logger"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// UtilModuleInitFunc returns an initializer for the util module
|
|
||||||
func UtilModuleInitFunc() func(*luajit.State) error {
|
|
||||||
return func(state *luajit.State) error {
|
|
||||||
return RegisterModule(state, "util", UtilModuleFunctions())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UtilModuleFunctions returns all functions for the util module
|
|
||||||
func UtilModuleFunctions() map[string]luajit.GoFunction {
|
|
||||||
return map[string]luajit.GoFunction{
|
|
||||||
"generate_token": GenerateToken,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateToken creates a cryptographically secure random token
|
|
||||||
func GenerateToken(s *luajit.State) int {
|
|
||||||
// Get the length from the Lua arguments (default to 32)
|
|
||||||
length := 32
|
|
||||||
if s.GetTop() >= 1 && s.IsNumber(1) {
|
|
||||||
length = int(s.ToNumber(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enforce minimum length for security
|
|
||||||
if length < 16 {
|
|
||||||
length = 16
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate secure random bytes
|
|
||||||
tokenBytes := make([]byte, length)
|
|
||||||
if _, err := rand.Read(tokenBytes); err != nil {
|
|
||||||
s.PushString("")
|
|
||||||
logger.Error("Failed to generate secure token: %v", err)
|
|
||||||
return 1 // Return empty string on error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode as base64
|
|
||||||
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
|
||||||
|
|
||||||
// Trim to requested length (base64 might be longer)
|
|
||||||
if len(token) > length {
|
|
||||||
token = token[:length]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Push the token to the Lua stack
|
|
||||||
s.PushString(token)
|
|
||||||
return 1 // One return value
|
|
||||||
}
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
"github.com/VictoriaMetrics/fastcache"
|
"github.com/VictoriaMetrics/fastcache"
|
||||||
"github.com/goccy/go-json"
|
"github.com/goccy/go-json"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -75,7 +76,7 @@ func (sm *SessionManager) GetSession(id string) *Session {
|
||||||
|
|
||||||
// Store back with updated timestamp
|
// Store back with updated timestamp
|
||||||
updatedData, _ := json.Marshal(session)
|
updatedData, _ := json.Marshal(session)
|
||||||
sm.cache.Set([]byte(id), updatedData) // Use updatedData, not data
|
sm.cache.Set([]byte(id), updatedData)
|
||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
@ -141,5 +142,39 @@ func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, ht
|
||||||
sm.cookieMaxAge = maxAge
|
sm.cookieMaxAge = maxAge
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSessionFromRequest extracts the session from a request context
|
||||||
|
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
|
||||||
|
cookie := ctx.Request.Header.Cookie(sm.cookieName)
|
||||||
|
if len(cookie) == 0 {
|
||||||
|
// No session cookie, create a new session
|
||||||
|
return sm.CreateSession()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session cookie exists, get the session
|
||||||
|
return sm.GetSession(string(cookie))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSessionToResponse adds the session cookie to an HTTP response
|
||||||
|
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
|
||||||
|
cookie := fasthttp.AcquireCookie()
|
||||||
|
defer fasthttp.ReleaseCookie(cookie)
|
||||||
|
|
||||||
|
sm.mu.RLock()
|
||||||
|
cookie.SetKey(sm.cookieName)
|
||||||
|
cookie.SetValue(session.ID)
|
||||||
|
cookie.SetPath(sm.cookiePath)
|
||||||
|
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
|
||||||
|
cookie.SetMaxAge(sm.cookieMaxAge)
|
||||||
|
|
||||||
|
if sm.cookieDomain != "" {
|
||||||
|
cookie.SetDomain(sm.cookieDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie.SetSecure(sm.cookieSecure)
|
||||||
|
sm.mu.RUnlock()
|
||||||
|
|
||||||
|
ctx.Response.Header.SetCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
// GlobalSessionManager is the default session manager instance
|
// GlobalSessionManager is the default session manager instance
|
||||||
var GlobalSessionManager = NewSessionManager()
|
var GlobalSessionManager = NewSessionManager()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user