// Listing metadata: 223 lines, 4.1 KiB, Go.

package auth
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"maps"
"os"
"sync"
"time"
)
// Defaults for session handling in this package.
const (
// SessionCookieName is the name of the session cookie. It is not referenced
// elsewhere in this file; presumably HTTP handlers use it to carry the
// session ID (verify against callers).
SessionCookieName = "dk_session"
// DefaultExpiration is how long a session remains valid after it is created
// by Create or refreshed by Update.
DefaultExpiration = 24 * time.Hour
// SessionIDLength is the number of random bytes in a session ID. The ID is
// hex-encoded, so the resulting string is twice this many characters long.
SessionIDLength = 32
)
// Session is the state the store keeps for one logged-in user. It is
// serialized to JSON when the store is persisted to disk.
type Session struct {
ID string `json:"-"` // Exclude from JSON since it's stored as the map key
// UserID, Username and Email are copied from the arguments given to Create.
UserID int `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
// CreatedAt is set when the session is created.
CreatedAt time.Time `json:"created_at"`
// ExpiresAt is the instant after which the store treats the session as
// expired; Update pushes it forward by DefaultExpiration.
ExpiresAt time.Time `json:"expires_at"`
// LastSeen is set on creation and refreshed by Update.
LastSeen time.Time `json:"last_seen"`
Data map[string]any `json:"data,omitempty"` // For storing additional session data
}
// SessionStore is an in-memory map of sessions guarded by a read/write mutex,
// with optional persistence of that map to a JSON file.
type SessionStore struct {
mu sync.RWMutex // guards sessions
sessions map[string]*Session // sessions keyed by session ID
filePath string // persistence file; an empty path disables loading and saving
saveInterval time.Duration // period of the background cleanup-and-save loop
stopChan chan struct{} // closed by Close to stop the background loop
}
// persistedData is the JSON document written to and read from the store's
// persistence file.
type persistedData struct {
// Sessions maps session ID to session; the ID itself is not part of the
// serialized Session value.
Sessions map[string]*Session `json:"sessions"`
// SavedAt records when the snapshot was written.
SavedAt time.Time `json:"saved_at"`
}
// NewSessionStore builds a SessionStore that persists to filePath, restores
// any sessions already saved there, and starts the background loop that
// prunes and saves sessions every five minutes.
//
// An empty filePath disables persistence; the pruning loop still runs. Call
// Close to stop the loop.
func NewSessionStore(filePath string) *SessionStore {
	s := new(SessionStore)
	s.filePath = filePath
	s.sessions = map[string]*Session{}
	s.saveInterval = 5 * time.Minute
	s.stopChan = make(chan struct{})

	// Restore first so the background loop starts from the loaded state.
	s.loadFromFile()
	s.startPeriodicSave()

	return s
}
// generateSessionID returns a new random session identifier: SessionIDLength
// bytes read from crypto/rand and hex-encoded, so the string is
// 2*SessionIDLength characters long.
//
// It panics if the random source reports an error. Carrying on would hand out
// an ID built from a zeroed or partially filled buffer, i.e. a guessable
// session token, so failing loudly is the only safe reaction.
func (s *SessionStore) generateSessionID() string {
	buf := make([]byte, SessionIDLength)
	if _, err := rand.Read(buf); err != nil {
		panic("auth: reading random bytes for session ID: " + err.Error())
	}
	return hex.EncodeToString(buf)
}
// Create registers a new session for the given user and returns it.
//
// The session gets a fresh random ID and expires DefaultExpiration after its
// creation time. All three timestamps are derived from a single clock read, so
// LastSeen equals CreatedAt and ExpiresAt is exactly CreatedAt plus
// DefaultExpiration.
//
// The returned pointer is the value held by the store, not a copy.
func (s *SessionStore) Create(userID int, username, email string) *Session {
	// Neither the ID nor the timestamps depend on store state, so build the
	// session before taking the write lock.
	now := time.Now()
	session := &Session{
		ID:        s.generateSessionID(),
		UserID:    userID,
		Username:  username,
		Email:     email,
		CreatedAt: now,
		ExpiresAt: now.Add(DefaultExpiration),
		LastSeen:  now,
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.sessions[session.ID] = session
	return session
}
// Get looks up the session with the given ID.
//
// It returns (session, true) when the session exists and has not expired. It
// returns (nil, false) when the ID is unknown or the session has expired; an
// expired session is removed from the store as a side effect.
//
// The returned pointer is the value held by the store, not a copy.
func (s *SessionStore) Get(sessionID string) (*Session, bool) {
	// A write lock is required: the expired branch deletes from the map, and
	// mutating a map while holding only a read lock races with other readers.
	s.mu.Lock()
	defer s.mu.Unlock()

	session, exists := s.sessions[sessionID]
	if !exists {
		return nil, false
	}
	if time.Now().After(session.ExpiresAt) {
		delete(s.sessions, sessionID)
		return nil, false
	}
	return session, true
}
// Update marks the session as seen now and extends its lifetime by
// DefaultExpiration from this moment.
//
// It returns false when the ID is unknown or the session has already expired;
// an expired session is removed from the store. It returns true once the
// session has been refreshed.
func (s *SessionStore) Update(sessionID string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	session, exists := s.sessions[sessionID]
	if !exists {
		return false
	}

	// One clock read drives both the expiry check and the refresh, so
	// ExpiresAt is exactly LastSeen plus DefaultExpiration.
	now := time.Now()
	if now.After(session.ExpiresAt) {
		delete(s.sessions, sessionID)
		return false
	}

	session.LastSeen = now
	session.ExpiresAt = now.Add(DefaultExpiration)
	return true
}
// Delete removes the session with the given ID from the store. Deleting an
// unknown ID is a no-op.
func (s *SessionStore) Delete(sessionID string) {
	s.mu.Lock()
	delete(s.sessions, sessionID)
	s.mu.Unlock()
}
// Cleanup drops every session whose ExpiresAt lies before the current time.
// It holds the write lock for the whole sweep.
func (s *SessionStore) Cleanup() {
	s.mu.Lock()
	defer s.mu.Unlock()

	cutoff := time.Now()
	for key, sess := range s.sessions {
		if !sess.ExpiresAt.Before(cutoff) {
			continue // still valid
		}
		delete(s.sessions, key)
	}
}
// loadFromFile restores unexpired sessions from the store's persistence file
// into the in-memory map.
//
// Loading is best effort: an empty file path, an unreadable file, or invalid
// JSON leaves the store unchanged and is not reported. Entries that are null
// or already expired are skipped.
func (s *SessionStore) loadFromFile() {
	if s.filePath == "" {
		return
	}

	raw, err := os.ReadFile(s.filePath)
	if err != nil {
		return // File might not exist yet
	}

	var persisted persistedData
	if err := json.Unmarshal(raw, &persisted); err != nil {
		return // Unreadable snapshot: start with an empty store instead of failing.
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	for id, session := range persisted.Sessions {
		// A null entry decodes to a nil pointer; dereferencing it would panic.
		if session == nil || !now.Before(session.ExpiresAt) {
			continue
		}
		// Session.ID is excluded from the JSON (it is the map key), so it has
		// to be put back by hand or reloaded sessions would have an empty ID.
		session.ID = id
		s.sessions[id] = session
	}
}
// saveToFile writes every session currently in the store, plus the save time,
// to the persistence file as indented JSON with mode 0600.
//
// It returns nil without doing anything when no file path is configured, and
// otherwise returns any marshalling or write error.
func (s *SessionStore) saveToFile() error {
	if s.filePath == "" {
		return nil
	}

	// Copy the Session values, not just the pointers, while holding the read
	// lock. Marshalling happens after the lock is released, and Update mutates
	// the stored structs in place, so encoding the shared pointers would race.
	// The copy is shallow: a session's Data map is still shared with the
	// stored session.
	s.mu.RLock()
	snapshot := make(map[string]*Session, len(s.sessions))
	for id, session := range s.sessions {
		if session == nil {
			continue
		}
		frozen := *session
		snapshot[id] = &frozen
	}
	s.mu.RUnlock()

	payload := persistedData{
		Sessions: snapshot,
		SavedAt:  time.Now(),
	}
	jsonData, err := json.MarshalIndent(payload, "", " ")
	if err != nil {
		return err
	}

	// NOTE(review): the file is rewritten in place, so a crash mid-write can
	// leave a truncated snapshot behind.
	return os.WriteFile(s.filePath, jsonData, 0600)
}
// startPeriodicSave launches the background goroutine that, once per
// saveInterval, prunes expired sessions and writes the store to disk. The
// goroutine exits when stopChan is closed.
func (s *SessionStore) startPeriodicSave() {
	go func() {
		ticker := time.NewTicker(s.saveInterval)
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
				s.Cleanup()
				// Best effort: nobody is listening for this error, and the
				// next tick (or Close) will try again.
				_ = s.saveToFile()
			case <-s.stopChan:
				// No final save here. Close writes the file itself right after
				// closing stopChan, and a second unsynchronised write of the
				// same path from this goroutine could interleave with it and
				// corrupt the snapshot.
				return
			}
		}
	}()
}
// Close stops the background save loop and writes the store to disk one last
// time, returning the error from that write.
//
// Close may be called more than once: later calls skip the already closed
// stop channel and simply save again.
func (s *SessionStore) Close() error {
	// Closing a closed channel panics, so check first. The store mutex makes
	// the check-then-close atomic with respect to concurrent Close calls.
	s.mu.Lock()
	select {
	case <-s.stopChan:
		// Already closed by an earlier call.
	default:
		close(s.stopChan)
	}
	s.mu.Unlock()

	return s.saveToFile()
}
// Stats reports how many sessions the store holds (total, including ones that
// have expired but not yet been pruned) and how many of those are still
// unexpired (active). It does not modify the store.
func (s *SessionStore) Stats() (total, active int) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	cutoff := time.Now()
	for _, sess := range s.sessions {
		if sess.ExpiresAt.After(cutoff) {
			active++
		}
	}
	return len(s.sessions), active
}