291 lines
5.6 KiB
Go
291 lines
5.6 KiB
Go
package sushi
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Session lifetime and identifier sizing.
const (
	// DefaultExpiration is how long a session stays valid after it is
	// created (NewSession) or refreshed (Touch).
	DefaultExpiration = 24 * time.Hour

	// IDLength is the number of random bytes in a session ID; the
	// hex-encoded ID string is twice this long.
	IDLength = 32
)

// Names under which the session is exposed to HTTP and to handlers.
const (
	// SessionCookieName is the name of the cookie that carries the session ID.
	SessionCookieName = "session_id"

	// SessionCtxKey is the ctx.UserValue key under which the current
	// *Session is stored.
	SessionCtxKey = "session"
)
|
|
|
|
// Session is one user's server-side session. It serializes to JSON with
// snake_case keys.
type Session struct {
	// ID is the hex-encoded random session identifier.
	ID string `json:"id"`

	// UserID is the user this session belongs to.
	UserID int `json:"user_id"`

	// ExpiresAt is the expiry instant in Unix seconds.
	ExpiresAt int64 `json:"expires_at"`

	// Data holds arbitrary per-session values. Flash entries live here
	// under keys prefixed with "flash_".
	Data map[string]any `json:"data"`
}
|
|
|
|
// SessionManager handles session storage and persistence
|
|
type SessionManager struct {
|
|
mu sync.RWMutex
|
|
sessions map[string]*Session
|
|
filePath string
|
|
}
|
|
|
|
// sessionData is the on-disk form of a Session. The ID is not stored here
// because Save and load use it as the key of the enclosing JSON object.
type sessionData struct {
	UserID    int            `json:"user_id"`
	ExpiresAt int64          `json:"expires_at"`
	Data      map[string]any `json:"data"`
}
|
|
|
|
var sessionManager *SessionManager
|
|
|
|
// InitSessions initializes the global session manager
|
|
func InitSessions(filePath string) {
|
|
if sessionManager != nil {
|
|
panic("session manager already initialized")
|
|
}
|
|
|
|
sessionManager = &SessionManager{
|
|
sessions: make(map[string]*Session),
|
|
filePath: filePath,
|
|
}
|
|
|
|
sessionManager.load()
|
|
}
|
|
|
|
// NewSession creates a new session
|
|
func NewSession(userID int) *Session {
|
|
return &Session{
|
|
ID: generateSessionID(),
|
|
UserID: userID,
|
|
ExpiresAt: time.Now().Add(DefaultExpiration).Unix(),
|
|
Data: make(map[string]any),
|
|
}
|
|
}
|
|
|
|
func generateSessionID() string {
|
|
bytes := make([]byte, IDLength)
|
|
rand.Read(bytes)
|
|
return hex.EncodeToString(bytes)
|
|
}
|
|
|
|
// Session methods
|
|
func (s *Session) IsExpired() bool {
|
|
return time.Now().Unix() > s.ExpiresAt
|
|
}
|
|
|
|
func (s *Session) Touch() {
|
|
s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix()
|
|
}
|
|
|
|
func (s *Session) Set(key string, value any) {
|
|
s.Data[key] = value
|
|
}
|
|
|
|
func (s *Session) Get(key string) (any, bool) {
|
|
value, exists := s.Data[key]
|
|
return value, exists
|
|
}
|
|
|
|
func (s *Session) Delete(key string) {
|
|
delete(s.Data, key)
|
|
}
|
|
|
|
func (s *Session) SetFlash(key string, value any) {
|
|
s.Set("flash_"+key, value)
|
|
}
|
|
|
|
func (s *Session) GetFlash(key string) (any, bool) {
|
|
flashKey := "flash_" + key
|
|
value, exists := s.Get(flashKey)
|
|
if exists {
|
|
s.Delete(flashKey)
|
|
}
|
|
return value, exists
|
|
}
|
|
|
|
func (s *Session) DeleteFlash(key string) {
|
|
s.Delete("flash_" + key)
|
|
}
|
|
|
|
func (s *Session) GetFlashMessage(key string) string {
|
|
if flash, exists := s.GetFlash(key); exists {
|
|
if msg, ok := flash.(string); ok {
|
|
return msg
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (s *Session) RegenerateID() {
|
|
oldID := s.ID
|
|
s.ID = generateSessionID()
|
|
|
|
if sessionManager != nil {
|
|
sessionManager.mu.Lock()
|
|
delete(sessionManager.sessions, oldID)
|
|
sessionManager.sessions[s.ID] = s
|
|
sessionManager.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
func (s *Session) SetUserID(userID int) {
|
|
s.UserID = userID
|
|
}
|
|
|
|
// GetCurrentSession retrieves the session from context
|
|
func GetCurrentSession(ctx Ctx) *Session {
|
|
if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
|
|
return sess
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SessionManager methods
|
|
func (sm *SessionManager) Create(userID int) *Session {
|
|
sess := NewSession(userID)
|
|
sm.mu.Lock()
|
|
sm.sessions[sess.ID] = sess
|
|
sm.mu.Unlock()
|
|
return sess
|
|
}
|
|
|
|
func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
|
|
sm.mu.RLock()
|
|
sess, exists := sm.sessions[sessionID]
|
|
sm.mu.RUnlock()
|
|
|
|
if !exists || sess.IsExpired() {
|
|
if exists {
|
|
sm.Delete(sessionID)
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
return sess, true
|
|
}
|
|
|
|
func (sm *SessionManager) Store(sess *Session) {
|
|
sm.mu.Lock()
|
|
sm.sessions[sess.ID] = sess
|
|
sm.mu.Unlock()
|
|
}
|
|
|
|
func (sm *SessionManager) Delete(sessionID string) {
|
|
sm.mu.Lock()
|
|
delete(sm.sessions, sessionID)
|
|
sm.mu.Unlock()
|
|
}
|
|
|
|
func (sm *SessionManager) Cleanup() {
|
|
sm.mu.Lock()
|
|
for id, sess := range sm.sessions {
|
|
if sess.IsExpired() {
|
|
delete(sm.sessions, id)
|
|
}
|
|
}
|
|
sm.mu.Unlock()
|
|
}
|
|
|
|
func (sm *SessionManager) load() {
|
|
if sm.filePath == "" {
|
|
return
|
|
}
|
|
|
|
data, err := os.ReadFile(sm.filePath)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var sessionsData map[string]*sessionData
|
|
if err := json.Unmarshal(data, &sessionsData); err != nil {
|
|
return
|
|
}
|
|
|
|
now := time.Now().Unix()
|
|
sm.mu.Lock()
|
|
for id, data := range sessionsData {
|
|
if data != nil && data.ExpiresAt > now {
|
|
sess := &Session{
|
|
ID: id,
|
|
UserID: data.UserID,
|
|
ExpiresAt: data.ExpiresAt,
|
|
Data: data.Data,
|
|
}
|
|
if sess.Data == nil {
|
|
sess.Data = make(map[string]any)
|
|
}
|
|
sm.sessions[id] = sess
|
|
}
|
|
}
|
|
sm.mu.Unlock()
|
|
}
|
|
|
|
func (sm *SessionManager) Save() error {
|
|
if sm.filePath == "" {
|
|
return nil
|
|
}
|
|
|
|
sm.Cleanup()
|
|
|
|
sm.mu.RLock()
|
|
sessionsData := make(map[string]*sessionData, len(sm.sessions))
|
|
for id, sess := range sm.sessions {
|
|
sessionsData[id] = &sessionData{
|
|
UserID: sess.UserID,
|
|
ExpiresAt: sess.ExpiresAt,
|
|
Data: sess.Data,
|
|
}
|
|
}
|
|
|
|
data, err := json.MarshalIndent(sessionsData, "", "\t")
|
|
sm.mu.RUnlock()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return os.WriteFile(sm.filePath, data, 0600)
|
|
}
|
|
|
|
// Package-level session functions
|
|
func CreateSession(userID int) *Session {
|
|
return sessionManager.Create(userID)
|
|
}
|
|
|
|
func GetSession(sessionID string) (*Session, bool) {
|
|
return sessionManager.Get(sessionID)
|
|
}
|
|
|
|
func StoreSession(sess *Session) {
|
|
sessionManager.Store(sess)
|
|
}
|
|
|
|
func CleanupSessions() {
|
|
sessionManager.Cleanup()
|
|
}
|
|
|
|
func SaveSessions() error {
|
|
return sessionManager.Save()
|
|
}
|
|
|
|
func SetSessionCookie(ctx Ctx, sessionID string) {
|
|
SetSecureCookie(ctx, CookieOptions{
|
|
Name: SessionCookieName,
|
|
Value: sessionID,
|
|
Path: "/",
|
|
Expires: time.Now().Add(24 * time.Hour),
|
|
HTTPOnly: true,
|
|
Secure: IsHTTPS(ctx),
|
|
SameSite: "lax",
|
|
})
|
|
}
|
|
|
|
// GetCurrentSession retrieves the session from context
|
|
func (ctx Ctx) GetCurrentSession() *Session {
|
|
if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
|
|
return sess
|
|
}
|
|
return nil
|
|
}
|