package sushi

import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"os"
	"path/filepath"
	"sync"
	"time"
)

const (
	// DefaultExpiration is how long a session stays valid after it is
	// created or touched; SetSessionCookie uses it for the cookie lifetime.
	DefaultExpiration = 24 * time.Hour

	// IDLength is the number of random bytes in a session ID. The ID string
	// is hex-encoded, so it is twice this many characters long.
	IDLength = 32

	// SessionCookieName is the name of the cookie that carries the session ID.
	SessionCookieName = "session_id"

	// SessionCtxKey is the context user-value key from which
	// GetCurrentSession reads the current *Session.
	SessionCtxKey = "session"
)

// Session represents a user session. ExpiresAt is a Unix timestamp in
// seconds. Data holds arbitrary values and is persisted as JSON by
// SessionManager.Save; after a reload, values have whatever types
// encoding/json produces for `any` (for example, numbers become float64).
//
// NOTE(review): Session has no lock of its own. Set, Delete and Touch mutate
// it without holding SessionManager.mu, while Save reads Data under that
// lock. Mutating a session concurrently with Save, or from several
// goroutines at once, is therefore a data race.
type Session struct {
	ID        string         `json:"id"`
	UserID    int            `json:"user_id"`
	ExpiresAt int64          `json:"expires_at"`
	Data      map[string]any `json:"data"`
}

// SessionManager is an in-memory session store guarded by mu. It is
// optionally persisted to a JSON file at filePath; an empty filePath
// disables persistence.
type SessionManager struct {
	mu       sync.RWMutex
	sessions map[string]*Session
	filePath string
}

// sessionData is the on-disk form of a Session. The session ID is not a
// field here; it is the key of the saved JSON object.
type sessionData struct {
	UserID    int            `json:"user_id"`
	ExpiresAt int64          `json:"expires_at"`
	Data      map[string]any `json:"data"`
}

// sessionManager is the process-wide manager used by the package-level
// helpers. InitSessions sets it exactly once.
var sessionManager *SessionManager

// InitSessions creates the global session manager and loads any unexpired
// sessions from filePath (pass "" for a purely in-memory store). It panics
// if the manager has already been initialized. The check is not
// synchronized, so InitSessions must not be called concurrently.
func InitSessions(filePath string) {
	if sessionManager != nil {
		panic("session manager already initialized")
	}
	sessionManager = &SessionManager{
		sessions: make(map[string]*Session),
		filePath: filePath,
	}
	sessionManager.load()
}

// NewSession returns a new session for userID with a fresh random ID, an
// empty Data map, and an expiry DefaultExpiration from now. It does not
// register the session with any manager.
func NewSession(userID int) *Session {
	return &Session{
		ID:        generateSessionID(),
		UserID:    userID,
		ExpiresAt: time.Now().Add(DefaultExpiration).Unix(),
		Data:      make(map[string]any),
	}
}

// generateSessionID returns IDLength cryptographically random bytes,
// hex-encoded. It panics if the system randomness source fails: falling back
// to the zero-filled buffer would yield a predictable, hijackable ID.
func generateSessionID() string {
	buf := make([]byte, IDLength)
	if _, err := rand.Read(buf); err != nil {
		panic("sushi: generating session ID: " + err.Error())
	}
	return hex.EncodeToString(buf)
}

// Session methods

// IsExpired reports whether the current time is past ExpiresAt.
func (s *Session) IsExpired() bool {
	return time.Now().Unix() > s.ExpiresAt
}

// Touch extends the session's expiry to DefaultExpiration from now.
func (s *Session) Touch() {
	s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix()
}

// Set stores value under key in the session's Data, allocating Data first
// if it is nil (for example, on a Session built without NewSession).
func (s *Session) Set(key string, value any) {
	if s.Data == nil {
		s.Data = make(map[string]any)
	}
	s.Data[key] = value
}

// Get returns the value stored under key and whether it was present.
func (s *Session) Get(key string) (any, bool) {
	value, exists := s.Data[key]
	return value, exists
}

// Delete removes key from the session's Data. It is a no-op if key is absent.
func (s *Session) Delete(key string) {
	delete(s.Data, key)
}

// SetFlash stores a one-shot value under key. The value is kept in Data
// under the "flash_" prefix and is removed by the next GetFlash.
func (s *Session) SetFlash(key string, value any) {
	s.Set("flash_"+key, value)
}

// GetFlash returns the flash value stored under key and removes it, so a
// second call reports it as absent.
func (s *Session) GetFlash(key string) (any, bool) {
	flashKey := "flash_" + key
	value, exists := s.Get(flashKey)
	if exists {
		s.Delete(flashKey)
	}
	return value, exists
}

// DeleteFlash removes the flash value stored under key without reading it.
func (s *Session) DeleteFlash(key string) {
	s.Delete("flash_" + key)
}

// GetFlashMessage consumes the flash value under key and returns it if it is
// a string. It returns "" when the flash is absent or not a string; in the
// latter case the flash is still removed.
func (s *Session) GetFlashMessage(key string) string {
	if flash, exists := s.GetFlash(key); exists {
		if msg, ok := flash.(string); ok {
			return msg
		}
	}
	return ""
}

// RegenerateID gives the session a fresh random ID (for example, after
// login, to defeat session fixation). If the global manager is initialized,
// the session is re-keyed there under the same lock: the old ID is removed
// and the session is stored under the new one.
func (s *Session) RegenerateID() {
	newID := generateSessionID()
	if sessionManager == nil {
		s.ID = newID
		return
	}
	sessionManager.mu.Lock()
	delete(sessionManager.sessions, s.ID)
	s.ID = newID
	sessionManager.sessions[newID] = s
	sessionManager.mu.Unlock()
}

// SetUserID sets the user the session belongs to.
func (s *Session) SetUserID(userID int) {
	s.UserID = userID
}

// GetCurrentSession returns the *Session stored in ctx under SessionCtxKey,
// or nil if there is none or the stored value is not a *Session.
func GetCurrentSession(ctx Ctx) *Session {
	if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
		return sess
	}
	return nil
}

// SessionManager methods

// Create makes a new session for userID and registers it in the store.
func (sm *SessionManager) Create(userID int) *Session {
	sess := NewSession(userID)
	sm.mu.Lock()
	sm.sessions[sess.ID] = sess
	sm.mu.Unlock()
	return sess
}

// Get returns the session stored under sessionID if it exists and has not
// expired. An expired session is evicted from the store and reported as
// not found.
func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
	sm.mu.RLock()
	sess, ok := sm.sessions[sessionID]
	sm.mu.RUnlock()
	if !ok {
		return nil, false
	}
	if sess.IsExpired() {
		// Re-check under the write lock. Between RUnlock and Lock another
		// goroutine may have touched or replaced this entry, and a live
		// session must not be evicted.
		sm.mu.Lock()
		if cur, still := sm.sessions[sessionID]; still && cur == sess && cur.IsExpired() {
			delete(sm.sessions, sessionID)
		}
		sm.mu.Unlock()
		return nil, false
	}
	return sess, true
}

// Store registers sess under its current ID, replacing any existing entry.
func (sm *SessionManager) Store(sess *Session) {
	sm.mu.Lock()
	sm.sessions[sess.ID] = sess
	sm.mu.Unlock()
}

// Delete removes the session stored under sessionID, if any.
func (sm *SessionManager) Delete(sessionID string) {
	sm.mu.Lock()
	delete(sm.sessions, sessionID)
	sm.mu.Unlock()
}

// Cleanup removes every expired session from the store.
func (sm *SessionManager) Cleanup() {
	sm.mu.Lock()
	for id, sess := range sm.sessions {
		if sess.IsExpired() {
			delete(sm.sessions, id)
		}
	}
	sm.mu.Unlock()
}

// load fills the store from sm.filePath. It is best-effort: a missing,
// unreadable or malformed file loads nothing, and null or already-expired
// entries are skipped.
func (sm *SessionManager) load() {
	if sm.filePath == "" {
		return
	}
	raw, err := os.ReadFile(sm.filePath)
	if err != nil {
		return // e.g. no sessions have been saved yet
	}
	var saved map[string]*sessionData
	if err := json.Unmarshal(raw, &saved); err != nil {
		return
	}

	now := time.Now().Unix()
	sm.mu.Lock()
	defer sm.mu.Unlock()
	for id, sd := range saved {
		if sd == nil || sd.ExpiresAt <= now {
			continue
		}
		sess := &Session{
			ID:        id,
			UserID:    sd.UserID,
			ExpiresAt: sd.ExpiresAt,
			Data:      sd.Data,
		}
		// A saved `"data": null` decodes to a nil map; Session methods
		// expect a usable map.
		if sess.Data == nil {
			sess.Data = make(map[string]any)
		}
		sm.sessions[id] = sess
	}
}

// Save prunes expired sessions and writes the remaining ones to sm.filePath
// as tab-indented JSON keyed by session ID. It is a no-op when filePath is
// "". The file is replaced atomically (temp file + rename), so a crash
// mid-save cannot leave a truncated file that load would silently discard.
func (sm *SessionManager) Save() error {
	if sm.filePath == "" {
		return nil
	}
	sm.Cleanup()

	sm.mu.RLock()
	snapshot := make(map[string]*sessionData, len(sm.sessions))
	for id, sess := range sm.sessions {
		snapshot[id] = &sessionData{
			UserID:    sess.UserID,
			ExpiresAt: sess.ExpiresAt,
			Data:      sess.Data,
		}
	}
	// Marshal before releasing the lock: the snapshot shares (does not copy)
	// each session's Data map.
	data, err := json.MarshalIndent(snapshot, "", "\t")
	sm.mu.RUnlock()
	if err != nil {
		return err
	}
	return writeFileAtomic(sm.filePath, data)
}

// writeFileAtomic writes data to a temporary file in path's directory,
// flushes it to disk, and renames it over path. Readers therefore see
// either the old file or the complete new one. os.CreateTemp creates the
// file with mode 0600, which keeps the saved sessions private to the owner.
func writeFileAtomic(path string, data []byte) (err error) {
	tmp, err := os.CreateTemp(filepath.Dir(path), "."+filepath.Base(path)+".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	defer func() {
		if err != nil {
			_ = tmp.Close() // may already be closed; only cleanup matters here
			_ = os.Remove(tmpName)
		}
	}()

	if _, err = tmp.Write(data); err != nil {
		return err
	}
	// Make the contents durable before the rename makes them visible.
	if err = tmp.Sync(); err != nil {
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmpName, path)
}

// Package-level session functions. They operate on the global manager and
// require InitSessions to have been called first.

// CreateSession creates and registers a new session for userID.
func CreateSession(userID int) *Session {
	return sessionManager.Create(userID)
}

// GetSession returns the unexpired session stored under sessionID.
func GetSession(sessionID string) (*Session, bool) {
	return sessionManager.Get(sessionID)
}

// StoreSession registers sess under its current ID.
func StoreSession(sess *Session) {
	sessionManager.Store(sess)
}

// CleanupSessions removes every expired session.
func CleanupSessions() {
	sessionManager.Cleanup()
}

// SaveSessions persists the global store; see SessionManager.Save.
func SaveSessions() error {
	return sessionManager.Save()
}

// SetSessionCookie sets the HTTP-only, SameSite=lax session cookie carrying
// sessionID. Its lifetime is DefaultExpiration, so it matches the session's
// own expiry, and it is marked Secure when IsHTTPS(ctx) reports true.
func SetSessionCookie(ctx Ctx, sessionID string) {
	SetSecureCookie(ctx, CookieOptions{
		Name:     SessionCookieName,
		Value:    sessionID,
		Path:     "/",
		Expires:  time.Now().Add(DefaultExpiration),
		HTTPOnly: true,
		Secure:   IsHTTPS(ctx),
		SameSite: "lax",
	})
}

// GetCurrentSession returns the *Session stored in ctx under SessionCtxKey,
// or nil if there is none. It is the method form of the package-level
// GetCurrentSession function.
func (ctx Ctx) GetCurrentSession() *Session {
	return GetCurrentSession(ctx)
}