181 lines
3.6 KiB
Go

package session
import (
	"encoding/json"
	"os"
	"path/filepath"
	"sync"
	"time"
)
type (
	// SessionManager keeps sessions in an in-memory map guarded by mu and can
	// persist them to, and restore them from, a JSON file at filePath.
	SessionManager struct {
		mu       sync.RWMutex        // guards sessions
		sessions map[string]*Session // keyed by session ID
		filePath string              // JSON persistence file; "" disables persistence
	}

	// sessionData is the on-disk JSON form of a single session. The session ID
	// is deliberately absent: it is the key of the enclosing JSON object.
	sessionData struct {
		UserID    int            `json:"user_id"`
		ExpiresAt int64          `json:"expires_at"`
		Data      map[string]any `json:"data"`
	}
)

// Manager is the process-wide SessionManager installed by Init; it is nil
// until Init runs. Prefer GetManager, which panics with a clear message
// instead of letting callers dereference a nil pointer.
var Manager *SessionManager
// Init creates the process-wide session manager, restores unexpired sessions
// from filePath (an empty path disables persistence), and installs the result
// as Manager.
//
// Init panics if Manager is already set. It reads and writes Manager without
// any locking, so it is intended to be called once before the manager is used
// concurrently.
func Init(filePath string) {
	if Manager != nil {
		panic("session manager already initialized")
	}

	sm := &SessionManager{
		filePath: filePath,
		sessions: make(map[string]*Session),
	}
	// Load before publishing so Manager only ever points at a fully loaded
	// manager.
	sm.load()
	Manager = sm
}
// GetManager returns the session manager installed by Init.
// It panics if Init has not been called yet.
func GetManager() *SessionManager {
	if m := Manager; m != nil {
		return m
	}
	panic("session manager not initialized")
}
// Create builds a new session for userID via New, registers it in memory
// under its ID, and returns it. Nothing is written to disk.
func (sm *SessionManager) Create(userID int) *Session {
	created := New(userID)

	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.sessions[created.ID] = created
	return created
}
// Get returns the session stored under sessionID and true. It returns nil
// and false if there is no such session or the stored session has expired.
// As a side effect, an expired session is evicted from memory.
func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
	sm.mu.RLock()
	sess, exists := sm.sessions[sessionID]
	sm.mu.RUnlock()

	if !exists {
		return nil, false
	}
	if !sess.IsExpired() {
		return sess, true
	}

	// Evict under the write lock, re-checking what is stored now: between
	// RUnlock and Lock another goroutine may have stored a fresh session under
	// this ID, and an unconditional delete would throw that live session away.
	sm.mu.Lock()
	if cur, ok := sm.sessions[sessionID]; ok && cur.IsExpired() {
		delete(sm.sessions, sessionID)
	}
	sm.mu.Unlock()
	return nil, false
}
// Store puts sess into the in-memory map under sess.ID, replacing any session
// already stored with that ID. It does not persist anything; Save does that.
func (sm *SessionManager) Store(sess *Session) {
	id := sess.ID

	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.sessions[id] = sess
}
// Delete drops the session with the given ID from memory. Deleting an ID that
// is not present is a no-op.
func (sm *SessionManager) Delete(sessionID string) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	delete(sm.sessions, sessionID)
}
// Cleanup evicts from memory every session whose IsExpired reports true.
func (sm *SessionManager) Cleanup() {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	// Deleting the current key while ranging over a map is permitted in Go.
	for id := range sm.sessions {
		if sm.sessions[id].IsExpired() {
			delete(sm.sessions, id)
		}
	}
}
// Stats reports how many sessions are held in memory (total) and how many of
// them have not expired (active). Expired sessions are counted in total but
// are not removed.
func (sm *SessionManager) Stats() (total, active int) {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	for _, s := range sm.sessions {
		if s.IsExpired() {
			continue
		}
		active++
	}
	return len(sm.sessions), active
}
// load restores unexpired sessions from the JSON file at sm.filePath into
// sm.sessions.
//
// It is best-effort. An empty filePath, a missing or unreadable file, or JSON
// that fails to decode all leave the map untouched. Entries that are null, or
// whose expires_at is not after the current Unix time, are skipped. A missing
// "data" object is replaced with an empty map.
//
// NOTE(review): read and decode errors are dropped, so a corrupt file looks
// the same as "no saved sessions", and a later Save will overwrite it. Values
// in Data come back as JSON-decoded types (for example, numbers become
// float64), not as the original Go types.
func (sm *SessionManager) load() {
	if sm.filePath == "" {
		return
	}
	raw, err := os.ReadFile(sm.filePath)
	if err != nil {
		return // missing or unreadable file: start with no restored sessions
	}
	var stored map[string]*sessionData
	if json.Unmarshal(raw, &stored) != nil {
		return // malformed JSON: ignore the file entirely
	}

	now := time.Now().Unix()
	sm.mu.Lock()
	defer sm.mu.Unlock()
	for id, entry := range stored {
		if entry == nil || entry.ExpiresAt <= now {
			continue
		}
		values := entry.Data
		if values == nil {
			values = make(map[string]any)
		}
		sm.sessions[id] = &Session{
			ID:        id,
			UserID:    entry.UserID,
			ExpiresAt: entry.ExpiresAt,
			Data:      values,
		}
	}
}
// Save evicts expired sessions (via Cleanup) and writes the remaining ones to
// sm.filePath as tab-indented JSON keyed by session ID. It is a no-op when
// filePath is empty.
//
// The file is replaced atomically. The JSON is written and fsynced to a
// temporary file in the same directory, which is then renamed over filePath.
// As a result, a crash mid-write or two concurrent Saves cannot leave a
// truncated or interleaved file behind. load silently discards such a file,
// which would lose every persisted session.
//
// NOTE(review): sess.Data is encoded while holding only sm.mu. If callers
// mutate a Session's Data concurrently, that is a data race unless Session has
// its own locking. Confirm.
func (sm *SessionManager) Save() error {
	if sm.filePath == "" {
		return nil
	}
	sm.Cleanup() // drop expired sessions so they are not persisted

	sm.mu.RLock()
	// Convert to the on-disk form; the ID lives in the map key, not the value.
	sessionsData := make(map[string]*sessionData, len(sm.sessions))
	for id, sess := range sm.sessions {
		sessionsData[id] = &sessionData{
			UserID:    sess.UserID,
			ExpiresAt: sess.ExpiresAt,
			Data:      sess.Data,
		}
	}
	data, err := json.MarshalIndent(sessionsData, "", "\t")
	sm.mu.RUnlock()
	if err != nil {
		return err
	}
	return writeFileAtomic(sm.filePath, data)
}

// writeFileAtomic replaces path with data by writing a temporary file next to
// it and renaming that file into place. os.CreateTemp creates the file with
// mode 0600, the same mode Save previously passed to os.WriteFile. On any
// failure the temporary file is removed and path is left untouched.
func writeFileAtomic(path string, data []byte) (err error) {
	tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".tmp-*")
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			_ = tmp.Close() // may already be closed; that error is irrelevant here
			_ = os.Remove(tmp.Name())
		}
	}()

	if _, err = tmp.Write(data); err != nil {
		return err
	}
	// Flush to stable storage before the rename makes the new contents visible.
	if err = tmp.Sync(); err != nil {
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmp.Name(), path)
}
// Close persists the current sessions by delegating to Save, which first
// evicts expired sessions from memory. It returns Save's error unchanged.
// Close releases nothing else, so the manager remains usable afterwards.
func (sm *SessionManager) Close() error {
	if err := sm.Save(); err != nil {
		return err
	}
	return nil
}