// Sushi/session.go (291 lines, 5.6 KiB, Go)

package sushi
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"os"
"sync"
"time"
)
const (
	// DefaultExpiration is the lifetime given to new sessions and the amount
	// Touch extends a session by.
	DefaultExpiration = time.Hour * 24

	// IDLength is the number of random bytes in a session ID; the hex-encoded
	// ID string is twice this long.
	IDLength = 32

	// SessionCookieName is the name of the cookie written by SetSessionCookie.
	SessionCookieName = "session_id"

	// SessionCtxKey is the user-value key under which the current session is
	// looked up on a Ctx.
	SessionCtxKey = "session"
)
// Session is the state kept for one visitor. Every field carries a JSON tag,
// so the struct can be encoded directly.
type Session struct {
	// ID is the random identifier the session manager keys this session by.
	ID string `json:"id"`

	// UserID is the user ID passed to NewSession or SetUserID.
	UserID int `json:"user_id"`

	// ExpiresAt is a Unix timestamp in seconds; see IsExpired.
	ExpiresAt int64 `json:"expires_at"`

	// Data holds arbitrary per-session values, including flash entries
	// (stored with a "flash_" key prefix).
	Data map[string]any `json:"data"`
}
// SessionManager keeps sessions in memory, keyed by session ID, and can
// persist them to a JSON file.
type SessionManager struct {
	mu       sync.RWMutex        // guards sessions
	sessions map[string]*Session // session ID -> session
	filePath string              // persistence file; "" disables load and Save
}
// sessionData is the on-disk form of a Session. The ID is omitted because it
// is the key of the JSON object that Save writes and load reads.
type sessionData struct {
	UserID    int            `json:"user_id"`
	ExpiresAt int64          `json:"expires_at"`
	Data      map[string]any `json:"data"`
}
// sessionManager is the process-wide manager used by the package-level
// session helpers. It stays nil until InitSessions runs.
var sessionManager *SessionManager

// InitSessions installs the global session manager and loads any sessions
// previously persisted at filePath (an empty filePath disables persistence).
// It panics if called more than once.
//
// NOTE: the "already initialized" check is not synchronized; call this once
// during startup, before serving requests.
func InitSessions(filePath string) {
	if sessionManager != nil {
		panic("session manager already initialized")
	}
	sm := new(SessionManager)
	sm.sessions = make(map[string]*Session)
	sm.filePath = filePath
	sessionManager = sm
	sm.load()
}
// NewSession builds a session for userID with a fresh random ID, an empty
// data map, and an expiry DefaultExpiration from now. It does not register
// the session with any manager; SessionManager.Create does that.
func NewSession(userID int) *Session {
	sess := new(Session)
	sess.ID = generateSessionID()
	sess.UserID = userID
	sess.ExpiresAt = time.Now().Add(DefaultExpiration).Unix()
	sess.Data = map[string]any{}
	return sess
}
// generateSessionID returns IDLength bytes from crypto/rand, hex-encoded, so
// the result is 2*IDLength characters long.
//
// The error from rand.Read is checked because Go versions before 1.24 could
// return one. Ignoring it would hand out a zero or partially zero, and thus
// guessable, session ID. There is no safe fallback, so a failure panics.
func generateSessionID() string {
	buf := make([]byte, IDLength)
	if _, err := rand.Read(buf); err != nil {
		panic("crypto/rand: unable to generate session ID: " + err.Error())
	}
	return hex.EncodeToString(buf)
}
// Session methods

// IsExpired reports whether the current Unix second is past ExpiresAt. A
// session is still valid during the ExpiresAt second itself.
func (s *Session) IsExpired() bool {
	now := time.Now().Unix()
	return s.ExpiresAt < now
}
// Touch moves the session's expiry to DefaultExpiration from now.
func (s *Session) Touch() {
	next := time.Now().Add(DefaultExpiration)
	s.ExpiresAt = next.Unix()
}
// Set stores value under key in the session's data map.
//
// The map is allocated on demand. Writing to a nil map panics, and Data is
// nil for a Session built as a literal or decoded from JSON with
// "data": null.
func (s *Session) Set(key string, value any) {
	if s.Data == nil {
		s.Data = make(map[string]any)
	}
	s.Data[key] = value
}
// Get returns the value stored under key and whether the key was present.
func (s *Session) Get(key string) (value any, exists bool) {
	value, exists = s.Data[key]
	return value, exists
}
// Delete removes key from the session data. It is a no-op if key is absent.
func (s *Session) Delete(key string) {
	data := s.Data
	delete(data, key)
}
// SetFlash stores a one-shot value for key. It lives in Data under
// "flash_"+key and is removed by the first GetFlash for that key.
func (s *Session) SetFlash(key string, value any) {
	flashKey := "flash_" + key
	s.Set(flashKey, value)
}
// GetFlash returns the flash value for key and consumes it: once read, the
// entry is removed from the session data. The bool reports whether a flash
// value was present.
func (s *Session) GetFlash(key string) (any, bool) {
	k := "flash_" + key
	v, ok := s.Get(k)
	if !ok {
		return v, false
	}
	s.Delete(k)
	return v, true
}
// DeleteFlash discards the flash value for key without reading it.
func (s *Session) DeleteFlash(key string) {
	flashKey := "flash_" + key
	s.Delete(flashKey)
}
// GetFlashMessage consumes the flash value for key and returns it if it is a
// string. It returns "" when there is no flash value or the value is not a
// string. The value is consumed either way.
func (s *Session) GetFlashMessage(key string) string {
	v, ok := s.GetFlash(key)
	if !ok {
		return ""
	}
	msg, _ := v.(string)
	return msg
}
// RegenerateID gives the session a fresh random ID. When the global session
// manager is initialized, the entry under the old ID is removed and the
// session is stored under the new one, even if it was not stored before.
// Rotating the ID when privileges change (e.g. at login) is the usual
// defense against session fixation.
func (s *Session) RegenerateID() {
	previous := s.ID
	s.ID = generateSessionID()

	sm := sessionManager
	if sm == nil {
		return
	}
	sm.mu.Lock()
	defer sm.mu.Unlock()
	delete(sm.sessions, previous)
	sm.sessions[s.ID] = s
}
// SetUserID sets the user the session belongs to. It does not change the
// session ID. NOTE(review): callers raising privileges (e.g. at login) likely
// want RegenerateID as well.
func (s *Session) SetUserID(userID int) {
	s.UserID = userID
}
// GetCurrentSession returns the *Session stored in ctx under SessionCtxKey.
// It returns nil when no value is set or the value is of another type.
func GetCurrentSession(ctx Ctx) *Session {
	sess, _ := ctx.UserValue(SessionCtxKey).(*Session)
	return sess
}
// SessionManager methods

// Create builds a new session for userID (see NewSession), registers it with
// the manager, and returns it.
func (sm *SessionManager) Create(userID int) *Session {
	s := NewSession(userID)

	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.sessions[s.ID] = s
	return s
}
// Get returns the live session stored under sessionID. An expired session is
// evicted and reported as missing (nil, false).
func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
	sm.mu.RLock()
	sess, ok := sm.sessions[sessionID]
	sm.mu.RUnlock()
	if !ok {
		return nil, false
	}
	if !sess.IsExpired() {
		return sess, true
	}

	// Between RUnlock and Lock another goroutine may have stored a fresh
	// session under this ID. Evict only if the entry is still the same
	// expired session we observed.
	sm.mu.Lock()
	if cur := sm.sessions[sessionID]; cur == sess && sess.IsExpired() {
		delete(sm.sessions, sessionID)
	}
	sm.mu.Unlock()
	return nil, false
}
// Store registers sess under sess.ID, replacing any session already stored
// with that ID.
//
// The unlock is deferred. A nil sess panics on sess.ID while the lock is
// held, and without defer a recovered panic would leave mu locked and
// deadlock every later call on the manager.
func (sm *SessionManager) Store(sess *Session) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.sessions[sess.ID] = sess
}
// Delete removes the session stored under sessionID, if any.
func (sm *SessionManager) Delete(sessionID string) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	delete(sm.sessions, sessionID)
}
// Cleanup evicts every expired session, holding the write lock for the whole
// scan.
func (sm *SessionManager) Cleanup() {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	for sessionID, s := range sm.sessions {
		if !s.IsExpired() {
			continue
		}
		// Deleting during range is permitted in Go.
		delete(sm.sessions, sessionID)
	}
}
// load populates sm.sessions from the JSON file at sm.filePath. It skips
// entries that are null or already expired, and gives loaded sessions an
// empty data map when none was stored.
//
// Loading is best-effort. An empty path, an unreadable or missing file, or
// malformed JSON leaves the manager without loaded sessions instead of
// failing startup.
func (sm *SessionManager) load() {
	if sm.filePath == "" {
		return
	}
	raw, err := os.ReadFile(sm.filePath)
	if err != nil {
		// A missing file is the normal first-run case.
		return
	}
	var stored map[string]*sessionData
	if err := json.Unmarshal(raw, &stored); err != nil {
		// NOTE(review): a corrupt file is ignored here and will be
		// overwritten by the next Save; consider surfacing this error.
		return
	}

	sm.mu.Lock()
	defer sm.mu.Unlock()
	for id, entry := range stored {
		if entry == nil {
			continue
		}
		sess := &Session{
			ID:        id,
			UserID:    entry.UserID,
			ExpiresAt: entry.ExpiresAt,
			Data:      entry.Data,
		}
		// Use IsExpired so the keep/drop boundary matches Get and Cleanup.
		if sess.IsExpired() {
			continue
		}
		if sess.Data == nil {
			sess.Data = make(map[string]any)
		}
		sm.sessions[id] = sess
	}
}
// Save purges expired sessions and writes the rest to sm.filePath as
// tab-indented JSON, keyed by session ID. It is a no-op when no file path was
// configured.
//
// The JSON goes to a sibling "<filePath>.tmp" file, which is then renamed
// over filePath. A crash mid-write therefore cannot leave a truncated file
// behind, which load would otherwise discard wholesale. The temp file is
// created with permission bits 0600 (before umask). Concurrent Save calls
// are not coordinated.
func (sm *SessionManager) Save() error {
	if sm.filePath == "" {
		return nil
	}
	sm.Cleanup()

	sm.mu.RLock()
	snapshot := make(map[string]*sessionData, len(sm.sessions))
	for id, sess := range sm.sessions {
		snapshot[id] = &sessionData{
			UserID:    sess.UserID,
			ExpiresAt: sess.ExpiresAt,
			Data:      sess.Data,
		}
	}
	// Marshal before releasing the lock: snapshot shares each session's Data
	// map instead of copying it.
	data, err := json.MarshalIndent(snapshot, "", "\t")
	sm.mu.RUnlock()
	if err != nil {
		return err
	}

	tmp := sm.filePath + ".tmp"
	if err := os.WriteFile(tmp, data, 0600); err != nil {
		return err
	}
	if err := os.Rename(tmp, sm.filePath); err != nil {
		_ = os.Remove(tmp) // best-effort cleanup; the rename error is what matters
		return err
	}
	return nil
}
// Package-level session functions
func CreateSession(userID int) *Session {
return sessionManager.Create(userID)
}
func GetSession(sessionID string) (*Session, bool) {
return sessionManager.Get(sessionID)
}
func StoreSession(sess *Session) {
sessionManager.Store(sess)
}
func CleanupSessions() {
sessionManager.Cleanup()
}
func SaveSessions() error {
return sessionManager.Save()
}
// SetSessionCookie writes the session cookie carrying sessionID. The cookie
// has these settings:
//   - path "/"
//   - HTTP-only
//   - SameSite "lax"
//   - Secure when IsHTTPS reports true for ctx
//   - expires DefaultExpiration from now, the same lifetime NewSession and
//     Touch give the server-side session, so the two cannot drift apart
func SetSessionCookie(ctx Ctx, sessionID string) {
	SetSecureCookie(ctx, CookieOptions{
		Name:     SessionCookieName,
		Value:    sessionID,
		Path:     "/",
		Expires:  time.Now().Add(DefaultExpiration),
		HTTPOnly: true,
		Secure:   IsHTTPS(ctx),
		SameSite: "lax",
	})
}
// GetCurrentSession is the method form of the package-level
// GetCurrentSession. It returns the *Session stored in ctx under
// SessionCtxKey, or nil when none is set. It delegates to the package
// function so the two lookups cannot diverge.
func (ctx Ctx) GetCurrentSession() *Session {
	// Inside a method body an unqualified name resolves to package scope,
	// so this calls the package-level function, not the method itself.
	return GetCurrentSession(ctx)
}