Compare commits
2 Commits
2a2ffc9cc5
...
82c588336d
Author | SHA1 | Date | |
---|---|---|---|
82c588336d | |||
21559bd6b7 |
15
Moonshark.go
15
Moonshark.go
|
@ -11,6 +11,7 @@ import (
|
||||||
"git.sharkk.net/Sky/Moonshark/core/logger"
|
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||||
"git.sharkk.net/Sky/Moonshark/core/routers"
|
"git.sharkk.net/Sky/Moonshark/core/routers"
|
||||||
"git.sharkk.net/Sky/Moonshark/core/runner"
|
"git.sharkk.net/Sky/Moonshark/core/runner"
|
||||||
|
"git.sharkk.net/Sky/Moonshark/core/sessions"
|
||||||
"git.sharkk.net/Sky/Moonshark/core/utils"
|
"git.sharkk.net/Sky/Moonshark/core/utils"
|
||||||
"git.sharkk.net/Sky/Moonshark/core/watchers"
|
"git.sharkk.net/Sky/Moonshark/core/watchers"
|
||||||
)
|
)
|
||||||
|
@ -166,10 +167,24 @@ func (s *Moonshark) initRunner() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize session manager
|
||||||
|
sessionManager := sessions.GlobalSessionManager
|
||||||
|
|
||||||
|
// Configure session cookies
|
||||||
|
sessionManager.SetCookieOptions(
|
||||||
|
"MSESSID", // name
|
||||||
|
"/", // path
|
||||||
|
"", // domain
|
||||||
|
false, // secure
|
||||||
|
true, // httpOnly
|
||||||
|
86400, // maxAge (1 day)
|
||||||
|
)
|
||||||
|
|
||||||
// Set up runner options
|
// Set up runner options
|
||||||
runnerOpts := []runner.RunnerOption{
|
runnerOpts := []runner.RunnerOption{
|
||||||
runner.WithPoolSize(s.Config.PoolSize),
|
runner.WithPoolSize(s.Config.PoolSize),
|
||||||
runner.WithLibDirs(s.Config.LibDirs...),
|
runner.WithLibDirs(s.Config.LibDirs...),
|
||||||
|
runner.WithSessionManager(sessionManager),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add debug option conditionally
|
// Add debug option conditionally
|
||||||
|
|
|
@ -164,6 +164,15 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
||||||
// Log bytecode size
|
// Log bytecode size
|
||||||
logger.Debug("Executing Lua route with %d bytes of bytecode", len(bytecode))
|
logger.Debug("Executing Lua route with %d bytes of bytecode", len(bytecode))
|
||||||
|
|
||||||
|
// Extract cookies instead of storing the raw request
|
||||||
|
cookieMap := make(map[string]any)
|
||||||
|
for _, cookie := range r.Cookies() {
|
||||||
|
cookieMap[cookie.Name] = cookie.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store cookie map instead of raw request
|
||||||
|
ctx.Set("_request_cookies", cookieMap)
|
||||||
|
|
||||||
// Add request info directly to context
|
// Add request info directly to context
|
||||||
ctx.Set("method", r.Method)
|
ctx.Set("method", r.Method)
|
||||||
ctx.Set("path", r.URL.Path)
|
ctx.Set("path", r.URL.Path)
|
||||||
|
|
|
@ -37,6 +37,12 @@ type StateWrapper struct {
|
||||||
index int // Index for debugging
|
index int // Index for debugging
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitHook is a function that runs before executing a script
|
||||||
|
type InitHook func(*luajit.State, *Context) error
|
||||||
|
|
||||||
|
// FinalizeHook is a function that runs after executing a script
|
||||||
|
type FinalizeHook func(*luajit.State, *Context, any) error
|
||||||
|
|
||||||
// LuaRunner runs Lua scripts using a pool of Lua states
|
// LuaRunner runs Lua scripts using a pool of Lua states
|
||||||
type LuaRunner struct {
|
type LuaRunner struct {
|
||||||
states []*StateWrapper // Pool of Lua states
|
states []*StateWrapper // Pool of Lua states
|
||||||
|
@ -47,6 +53,8 @@ type LuaRunner struct {
|
||||||
isRunning atomic.Bool // Flag indicating if the runner is active
|
isRunning atomic.Bool // Flag indicating if the runner is active
|
||||||
mu sync.RWMutex // Mutex for thread safety
|
mu sync.RWMutex // Mutex for thread safety
|
||||||
debug bool // Enable debug logging
|
debug bool // Enable debug logging
|
||||||
|
initHooks []InitHook // Hooks to run before script execution
|
||||||
|
finalizeHooks []FinalizeHook // Hooks to run after script execution
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithPoolSize sets the state pool size
|
// WithPoolSize sets the state pool size
|
||||||
|
@ -91,6 +99,8 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
||||||
runner := &LuaRunner{
|
runner := &LuaRunner{
|
||||||
poolSize: runtime.GOMAXPROCS(0),
|
poolSize: runtime.GOMAXPROCS(0),
|
||||||
debug: false,
|
debug: false,
|
||||||
|
initHooks: make([]InitHook, 0),
|
||||||
|
finalizeHooks: make([]FinalizeHook, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options
|
// Apply options
|
||||||
|
@ -208,6 +218,20 @@ func (r *LuaRunner) initState(index int) (*StateWrapper, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddInitHook adds a hook to be called before script execution
|
||||||
|
func (r *LuaRunner) AddInitHook(hook InitHook) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.initHooks = append(r.initHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFinalizeHook adds a hook to be called after script execution
|
||||||
|
func (r *LuaRunner) AddFinalizeHook(hook FinalizeHook) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.finalizeHooks = append(r.finalizeHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
// RunWithContext executes a Lua script with context and timeout
|
// RunWithContext executes a Lua script with context and timeout
|
||||||
func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
||||||
if !r.isRunning.Load() {
|
if !r.isRunning.Load() {
|
||||||
|
@ -246,6 +270,11 @@ func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx
|
||||||
|
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
state := r.states[stateIndex]
|
state := r.states[stateIndex]
|
||||||
|
// Copy hooks to ensure we don't hold the lock during execution
|
||||||
|
initHooks := make([]InitHook, len(r.initHooks))
|
||||||
|
copy(initHooks, r.initHooks)
|
||||||
|
finalizeHooks := make([]FinalizeHook, len(r.finalizeHooks))
|
||||||
|
copy(finalizeHooks, r.finalizeHooks)
|
||||||
r.mu.RUnlock()
|
r.mu.RUnlock()
|
||||||
|
|
||||||
if state == nil {
|
if state == nil {
|
||||||
|
@ -258,6 +287,19 @@ func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run init hooks
|
||||||
|
for _, hook := range initHooks {
|
||||||
|
if err := hook(state.state, execCtx); err != nil {
|
||||||
|
result = JobResult{nil, err}
|
||||||
|
// Send result and return early
|
||||||
|
select {
|
||||||
|
case resultChan <- result:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Convert context
|
// Convert context
|
||||||
var ctxMap map[string]any
|
var ctxMap map[string]any
|
||||||
if execCtx != nil {
|
if execCtx != nil {
|
||||||
|
@ -266,6 +308,16 @@ func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx
|
||||||
|
|
||||||
// Execute in sandbox
|
// Execute in sandbox
|
||||||
value, err := state.sandbox.Execute(state.state, bytecode, ctxMap)
|
value, err := state.sandbox.Execute(state.state, bytecode, ctxMap)
|
||||||
|
|
||||||
|
// Run finalize hooks
|
||||||
|
for _, hook := range finalizeHooks {
|
||||||
|
hookErr := hook(state.state, execCtx, value)
|
||||||
|
if hookErr != nil && err == nil {
|
||||||
|
// Only override nil errors
|
||||||
|
err = hookErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
result = JobResult{value, err}
|
result = JobResult{value, err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
176
core/runner/Session.go
Normal file
176
core/runner/Session.go
Normal file
|
@ -0,0 +1,176 @@
|
||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LuaSessionModule provides session functionality to Lua scripts
|
||||||
|
const LuaSessionModule = `
|
||||||
|
-- Global table to store session data
|
||||||
|
__session_data = __session_data or {}
|
||||||
|
__session_id = __session_id or nil
|
||||||
|
__session_modified = false
|
||||||
|
|
||||||
|
-- Session module implementation
|
||||||
|
local session = {
|
||||||
|
-- Get a session value
|
||||||
|
get = function(key)
|
||||||
|
if type(key) ~= "string" then
|
||||||
|
error("session.get: key must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
if __session_data and __session_data[key] then
|
||||||
|
return __session_data[key]
|
||||||
|
end
|
||||||
|
|
||||||
|
return nil
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Set a session value
|
||||||
|
set = function(key, value)
|
||||||
|
if type(key) ~= "string" then
|
||||||
|
error("session.set: key must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Ensure session data table exists
|
||||||
|
__session_data = __session_data or {}
|
||||||
|
|
||||||
|
-- Store value
|
||||||
|
__session_data[key] = value
|
||||||
|
|
||||||
|
-- Mark session as modified
|
||||||
|
__session_modified = true
|
||||||
|
|
||||||
|
return true
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Delete a session value
|
||||||
|
delete = function(key)
|
||||||
|
if type(key) ~= "string" then
|
||||||
|
error("session.delete: key must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
if __session_data then
|
||||||
|
__session_data[key] = nil
|
||||||
|
__session_modified = true
|
||||||
|
end
|
||||||
|
|
||||||
|
return true
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Clear all session data
|
||||||
|
clear = function()
|
||||||
|
__session_data = {}
|
||||||
|
__session_modified = true
|
||||||
|
return true
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Get the session ID
|
||||||
|
get_id = function()
|
||||||
|
return __session_id or nil
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Get all session data
|
||||||
|
get_all = function()
|
||||||
|
local result = {}
|
||||||
|
for k, v in pairs(__session_data or {}) do
|
||||||
|
result[k] = v
|
||||||
|
end
|
||||||
|
return result
|
||||||
|
end,
|
||||||
|
|
||||||
|
-- Check if session has a key
|
||||||
|
has = function(key)
|
||||||
|
if type(key) ~= "string" then
|
||||||
|
error("session.has: key must be a string", 2)
|
||||||
|
end
|
||||||
|
|
||||||
|
return __session_data and __session_data[key] ~= nil
|
||||||
|
end
|
||||||
|
}
|
||||||
|
|
||||||
|
-- Install session module
|
||||||
|
_G.session = session
|
||||||
|
|
||||||
|
-- Make sure the session module is accessible in sandbox
|
||||||
|
if __env_system and __env_system.base_env then
|
||||||
|
__env_system.base_env.session = session
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Hook into script execution to preserve session state
|
||||||
|
local old_execute_script = __execute_script
|
||||||
|
if old_execute_script then
|
||||||
|
__execute_script = function(fn, ctx)
|
||||||
|
-- Reset modification flag at the start of request
|
||||||
|
__session_modified = false
|
||||||
|
|
||||||
|
-- Execute original function
|
||||||
|
return old_execute_script(fn, ctx)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
`
|
||||||
|
|
||||||
|
// GetSessionData extracts session data from Lua state
|
||||||
|
func GetSessionData(state *luajit.State) (string, map[string]any, bool) {
|
||||||
|
// Check if session was modified
|
||||||
|
state.GetGlobal("__session_modified")
|
||||||
|
modified := state.ToBoolean(-1)
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
if !modified {
|
||||||
|
return "", nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get session ID
|
||||||
|
state.GetGlobal("__session_id")
|
||||||
|
sessionID := state.ToString(-1)
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
// Get session data
|
||||||
|
state.GetGlobal("__session_data")
|
||||||
|
if !state.IsTable(-1) {
|
||||||
|
state.Pop(1)
|
||||||
|
return sessionID, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := state.ToTable(-1)
|
||||||
|
state.Pop(1)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to extract session data: %v", err)
|
||||||
|
return sessionID, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessionID, data, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSessionData sets session data in Lua state
|
||||||
|
func SetSessionData(state *luajit.State, sessionID string, data map[string]any) error {
|
||||||
|
// Set session ID
|
||||||
|
state.PushString(sessionID)
|
||||||
|
state.SetGlobal("__session_id")
|
||||||
|
|
||||||
|
// Set session data
|
||||||
|
if data == nil {
|
||||||
|
data = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := state.PushTable(data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
state.SetGlobal("__session_data")
|
||||||
|
|
||||||
|
// Reset modification flag
|
||||||
|
state.PushBoolean(false)
|
||||||
|
state.SetGlobal("__session_modified")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionModuleInitFunc returns an initializer for the session module
|
||||||
|
func SessionModuleInitFunc() StateInitFunc {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
return state.DoString(LuaSessionModule)
|
||||||
|
}
|
||||||
|
}
|
186
core/runner/SessionHandler.go
Normal file
186
core/runner/SessionHandler.go
Normal file
|
@ -0,0 +1,186 @@
|
||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
"git.sharkk.net/Sky/Moonshark/core/logger"
|
||||||
|
"git.sharkk.net/Sky/Moonshark/core/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionHandler handles session management for Lua scripts
|
||||||
|
type SessionHandler struct {
|
||||||
|
manager *sessions.SessionManager
|
||||||
|
debugLog bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionHandler creates a new session handler
|
||||||
|
func NewSessionHandler(manager *sessions.SessionManager) *SessionHandler {
|
||||||
|
return &SessionHandler{
|
||||||
|
manager: manager,
|
||||||
|
debugLog: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableDebug enables debug logging
|
||||||
|
func (h *SessionHandler) EnableDebug() {
|
||||||
|
h.debugLog = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// debug logs a message if debug is enabled
|
||||||
|
func (h *SessionHandler) debug(format string, args ...interface{}) {
|
||||||
|
if h.debugLog {
|
||||||
|
logger.Debug("[SessionHandler] "+format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSessionManager creates a RunnerOption to add session support
|
||||||
|
func WithSessionManager(manager *sessions.SessionManager) RunnerOption {
|
||||||
|
return func(r *LuaRunner) {
|
||||||
|
handler := NewSessionHandler(manager)
|
||||||
|
|
||||||
|
// Register the session module
|
||||||
|
RegisterCoreModule("session", SessionModuleInitFunc())
|
||||||
|
|
||||||
|
// Add hooks to the runner
|
||||||
|
r.AddInitHook(handler.preRequestHook)
|
||||||
|
r.AddFinalizeHook(handler.postRequestHook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preRequestHook is called before executing a request
|
||||||
|
func (h *SessionHandler) preRequestHook(state *luajit.State, ctx *Context) error {
|
||||||
|
h.debug("Running pre-request session hook")
|
||||||
|
|
||||||
|
// Check if we have cookie information in context
|
||||||
|
// Instead of raw request, we now look for the cookie map
|
||||||
|
if ctx == nil || ctx.Values["_request_cookies"] == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract cookies from context
|
||||||
|
cookies, ok := ctx.Values["_request_cookies"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the session ID from cookies
|
||||||
|
cookieName := h.manager.CookieOptions()["name"].(string)
|
||||||
|
var sessionID string
|
||||||
|
|
||||||
|
// Check if our session cookie exists
|
||||||
|
if cookieValue, exists := cookies[cookieName]; exists {
|
||||||
|
if strValue, ok := cookieValue.(string); ok && strValue != "" {
|
||||||
|
sessionID = strValue
|
||||||
|
h.debug("Found existing session ID: %s", sessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no session ID found, create new session
|
||||||
|
if sessionID == "" {
|
||||||
|
// Create a new session
|
||||||
|
session := h.manager.CreateSession()
|
||||||
|
sessionID = session.ID
|
||||||
|
h.debug("Created new session with ID: %s", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the session ID in the context for later use
|
||||||
|
ctx.Set("_session_id", sessionID)
|
||||||
|
|
||||||
|
// Get the session data
|
||||||
|
session := h.manager.GetSession(sessionID)
|
||||||
|
sessionData := session.GetAll()
|
||||||
|
|
||||||
|
// Set session data in Lua state
|
||||||
|
if err := SetSessionData(state, sessionID, sessionData); err != nil {
|
||||||
|
h.debug("Failed to set session data: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
h.debug("Session data initialized successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// postRequestHook is called after executing a request
|
||||||
|
func (h *SessionHandler) postRequestHook(state *luajit.State, ctx *Context, result any) error {
|
||||||
|
h.debug("Running post-request session hook")
|
||||||
|
|
||||||
|
// Check if session was modified
|
||||||
|
modifiedID, modifiedData, modified := GetSessionData(state)
|
||||||
|
if !modified {
|
||||||
|
h.debug("Session not modified, skipping")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the original session ID from context
|
||||||
|
var sessionID string
|
||||||
|
if ctx != nil {
|
||||||
|
if id, ok := ctx.Values["_session_id"].(string); ok {
|
||||||
|
sessionID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the original session ID if the modified one is empty
|
||||||
|
if modifiedID == "" {
|
||||||
|
modifiedID = sessionID
|
||||||
|
}
|
||||||
|
|
||||||
|
if modifiedID == "" {
|
||||||
|
h.debug("No session ID found, cannot persist session data")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
h.debug("Persisting modified session data for ID: %s", modifiedID)
|
||||||
|
|
||||||
|
// Update session in manager
|
||||||
|
session := h.manager.GetSession(modifiedID)
|
||||||
|
session.Clear() // clear to sync deleted values
|
||||||
|
for k, v := range modifiedData {
|
||||||
|
session.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add session cookie to result if it's an HTTP response
|
||||||
|
if httpResp, ok := result.(*HTTPResponse); ok {
|
||||||
|
h.addSessionCookie(httpResp, modifiedID)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.debug("Session data persisted successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addSessionCookie adds a session cookie to an HTTP response
|
||||||
|
func (h *SessionHandler) addSessionCookie(resp *HTTPResponse, sessionID string) {
|
||||||
|
// Get cookie options
|
||||||
|
opts := h.manager.CookieOptions()
|
||||||
|
|
||||||
|
// Check if session cookie is already set
|
||||||
|
cookieName := opts["name"].(string)
|
||||||
|
for _, cookie := range resp.Cookies {
|
||||||
|
if cookie.Name == cookieName {
|
||||||
|
h.debug("Session cookie already set in response")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.debug("Adding session cookie to response")
|
||||||
|
|
||||||
|
// Create and add cookie
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: cookieName,
|
||||||
|
Value: sessionID,
|
||||||
|
Path: opts["path"].(string),
|
||||||
|
HttpOnly: opts["http_only"].(bool),
|
||||||
|
MaxAge: opts["max_age"].(int),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional cookie parameters
|
||||||
|
if domain, ok := opts["domain"].(string); ok && domain != "" {
|
||||||
|
cookie.Domain = domain
|
||||||
|
}
|
||||||
|
|
||||||
|
if secure, ok := opts["secure"].(bool); ok {
|
||||||
|
cookie.Secure = secure
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Cookies = append(resp.Cookies, cookie)
|
||||||
|
}
|
71
core/sessions/Session.go
Normal file
71
core/sessions/Session.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Session stores data for a single user session
|
||||||
|
type Session struct {
|
||||||
|
ID string
|
||||||
|
Data map[string]any
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
mu sync.RWMutex // Protect concurrent access to Data
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSession creates a new session with the given ID
|
||||||
|
func NewSession(id string) *Session {
|
||||||
|
now := time.Now()
|
||||||
|
return &Session{
|
||||||
|
ID: id,
|
||||||
|
Data: make(map[string]any),
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value from the session
|
||||||
|
func (s *Session) Get(key string) any {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.Data[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set stores a value in the session
|
||||||
|
func (s *Session) Set(key string, value any) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.Data[key] = value
|
||||||
|
s.UpdatedAt = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a value from the session
|
||||||
|
func (s *Session) Delete(key string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.Data, key)
|
||||||
|
s.UpdatedAt = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all data from the session
|
||||||
|
func (s *Session) Clear() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.Data = make(map[string]any)
|
||||||
|
s.UpdatedAt = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll returns a copy of all session data
|
||||||
|
func (s *Session) GetAll() map[string]any {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
// Create a copy to avoid concurrent map access issues
|
||||||
|
copy := make(map[string]any, len(s.Data))
|
||||||
|
for k, v := range s.Data {
|
||||||
|
copy[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return copy
|
||||||
|
}
|
153
core/sessions/SessionManager.go
Normal file
153
core/sessions/SessionManager.go
Normal file
|
@ -0,0 +1,153 @@
|
||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionManager handles multiple sessions
|
||||||
|
type SessionManager struct {
|
||||||
|
sessions map[string]*Session
|
||||||
|
mu sync.RWMutex
|
||||||
|
cookieName string
|
||||||
|
cookiePath string
|
||||||
|
cookieDomain string
|
||||||
|
cookieSecure bool
|
||||||
|
cookieHTTPOnly bool
|
||||||
|
cookieMaxAge int
|
||||||
|
gcInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionManager creates a new session manager
|
||||||
|
func NewSessionManager() *SessionManager {
|
||||||
|
sm := &SessionManager{
|
||||||
|
sessions: make(map[string]*Session),
|
||||||
|
cookieName: "MSESSID",
|
||||||
|
cookiePath: "/",
|
||||||
|
cookieHTTPOnly: true,
|
||||||
|
cookieMaxAge: 86400, // 1 day
|
||||||
|
gcInterval: time.Hour,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the garbage collector
|
||||||
|
go sm.startGC()
|
||||||
|
|
||||||
|
return sm
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSessionID creates a cryptographically secure random session ID
|
||||||
|
func (sm *SessionManager) generateSessionID() string {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return time.Now().String() // Fallback
|
||||||
|
}
|
||||||
|
return base64.URLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist
|
||||||
|
func (sm *SessionManager) GetSession(id string) *Session {
|
||||||
|
sm.mu.RLock()
|
||||||
|
session, exists := sm.sessions[id]
|
||||||
|
sm.mu.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new session if it doesn't exist
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
// Double check to avoid race conditions
|
||||||
|
if session, exists = sm.sessions[id]; exists {
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
session = NewSession(id)
|
||||||
|
sm.sessions[id] = session
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSession generates a new session with a unique ID
|
||||||
|
func (sm *SessionManager) CreateSession() *Session {
|
||||||
|
id := sm.generateSessionID()
|
||||||
|
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
session := NewSession(id)
|
||||||
|
sm.sessions[id] = session
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
// DestroySession removes a session
|
||||||
|
func (sm *SessionManager) DestroySession(id string) {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
delete(sm.sessions, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// startGC starts the garbage collector to clean up expired sessions
|
||||||
|
func (sm *SessionManager) startGC() {
|
||||||
|
ticker := time.NewTicker(sm.gcInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
sm.gc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// gc removes expired sessions (inactive for 24 hours)
|
||||||
|
func (sm *SessionManager) gc() {
|
||||||
|
expiry := time.Now().Add(-24 * time.Hour)
|
||||||
|
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
for id, session := range sm.sessions {
|
||||||
|
session.mu.RLock()
|
||||||
|
lastUpdated := session.UpdatedAt
|
||||||
|
session.mu.RUnlock()
|
||||||
|
|
||||||
|
if lastUpdated.Before(expiry) {
|
||||||
|
delete(sm.sessions, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionCount returns the number of active sessions
|
||||||
|
func (sm *SessionManager) GetSessionCount() int {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
return len(sm.sessions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CookieOptions returns the cookie options for this session manager
|
||||||
|
func (sm *SessionManager) CookieOptions() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"name": sm.cookieName,
|
||||||
|
"path": sm.cookiePath,
|
||||||
|
"domain": sm.cookieDomain,
|
||||||
|
"secure": sm.cookieSecure,
|
||||||
|
"http_only": sm.cookieHTTPOnly,
|
||||||
|
"max_age": sm.cookieMaxAge,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCookieOptions configures cookie parameters
|
||||||
|
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
sm.cookieName = name
|
||||||
|
sm.cookiePath = path
|
||||||
|
sm.cookieDomain = domain
|
||||||
|
sm.cookieSecure = secure
|
||||||
|
sm.cookieHTTPOnly = httpOnly
|
||||||
|
sm.cookieMaxAge = maxAge
|
||||||
|
}
|
||||||
|
|
||||||
|
// GlobalSessionManager is the default session manager instance
|
||||||
|
var GlobalSessionManager = NewSessionManager()
|
Loading…
Reference in New Issue
Block a user