223 lines
4.1 KiB
Go
223 lines
4.1 KiB
Go
// Package auth manages user sessions: an in-memory SessionStore with optional
// periodic persistence to a JSON file.
package auth
|
|
|
|
import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"maps"
	"os"
	"path/filepath"
	"sync"
	"time"
)
|
|
|
|
const (
	// SessionCookieName is the cookie name reserved for session IDs.
	// It is not referenced in this file; presumably HTTP handlers use it.
	SessionCookieName = "dk_session"

	// DefaultExpiration is how long a session stays valid after Create or
	// after the most recent successful Update.
	DefaultExpiration = time.Hour * 24

	// SessionIDLength is the number of random bytes in a session ID. The ID
	// string is hex-encoded, so it is twice this many characters long.
	SessionIDLength = 32
)
|
|
|
|
// Session describes one user's session as tracked by SessionStore.
type Session struct {
	// ID is the session identifier. It is excluded from JSON because the
	// persisted file stores sessions in a map keyed by this ID.
	ID string `json:"-"`

	// Identity of the user that owns the session.
	UserID   int    `json:"user_id"`
	Username string `json:"username"`
	Email    string `json:"email"`

	// Lifecycle timestamps. SessionStore treats a session as expired once the
	// current time is after ExpiresAt.
	CreatedAt time.Time `json:"created_at"`
	ExpiresAt time.Time `json:"expires_at"`
	LastSeen  time.Time `json:"last_seen"`

	// Data holds arbitrary extra values attached to the session; it is left
	// out of the JSON encoding when empty.
	Data map[string]any `json:"data,omitempty"`
}
|
|
|
|
type SessionStore struct {
|
|
mu sync.RWMutex
|
|
sessions map[string]*Session
|
|
filePath string
|
|
saveInterval time.Duration
|
|
stopChan chan struct{}
|
|
}
|
|
|
|
type persistedData struct {
|
|
Sessions map[string]*Session `json:"sessions"`
|
|
SavedAt time.Time `json:"saved_at"`
|
|
}
|
|
|
|
func NewSessionStore(filePath string) *SessionStore {
|
|
store := &SessionStore{
|
|
sessions: make(map[string]*Session),
|
|
filePath: filePath,
|
|
saveInterval: 5 * time.Minute,
|
|
stopChan: make(chan struct{}),
|
|
}
|
|
|
|
store.loadFromFile()
|
|
store.startPeriodicSave()
|
|
|
|
return store
|
|
}
|
|
|
|
func (s *SessionStore) generateSessionID() string {
|
|
bytes := make([]byte, SessionIDLength)
|
|
rand.Read(bytes)
|
|
return hex.EncodeToString(bytes)
|
|
}
|
|
|
|
func (s *SessionStore) Create(userID int, username, email string) *Session {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
session := &Session{
|
|
ID: s.generateSessionID(),
|
|
UserID: userID,
|
|
Username: username,
|
|
Email: email,
|
|
CreatedAt: time.Now(),
|
|
ExpiresAt: time.Now().Add(DefaultExpiration),
|
|
LastSeen: time.Now(),
|
|
}
|
|
|
|
s.sessions[session.ID] = session
|
|
return session
|
|
}
|
|
|
|
func (s *SessionStore) Get(sessionID string) (*Session, bool) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
session, exists := s.sessions[sessionID]
|
|
if !exists {
|
|
return nil, false
|
|
}
|
|
|
|
if time.Now().After(session.ExpiresAt) {
|
|
delete(s.sessions, sessionID)
|
|
return nil, false
|
|
}
|
|
|
|
return session, true
|
|
}
|
|
|
|
func (s *SessionStore) Update(sessionID string) bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
session, exists := s.sessions[sessionID]
|
|
if !exists {
|
|
return false
|
|
}
|
|
|
|
if time.Now().After(session.ExpiresAt) {
|
|
delete(s.sessions, sessionID)
|
|
return false
|
|
}
|
|
|
|
session.LastSeen = time.Now()
|
|
session.ExpiresAt = time.Now().Add(DefaultExpiration)
|
|
return true
|
|
}
|
|
|
|
func (s *SessionStore) Delete(sessionID string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
delete(s.sessions, sessionID)
|
|
}
|
|
|
|
func (s *SessionStore) Cleanup() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
for id, session := range s.sessions {
|
|
if now.After(session.ExpiresAt) {
|
|
delete(s.sessions, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SessionStore) loadFromFile() {
|
|
if s.filePath == "" {
|
|
return
|
|
}
|
|
|
|
data, err := os.ReadFile(s.filePath)
|
|
if err != nil {
|
|
return // File might not exist yet
|
|
}
|
|
|
|
var persisted persistedData
|
|
if err := json.Unmarshal(data, &persisted); err != nil {
|
|
return
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
for id, session := range persisted.Sessions {
|
|
if now.Before(session.ExpiresAt) {
|
|
s.sessions[id] = session
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SessionStore) saveToFile() error {
|
|
if s.filePath == "" {
|
|
return nil
|
|
}
|
|
|
|
s.mu.RLock()
|
|
sessionsCopy := make(map[string]*Session)
|
|
maps.Copy(sessionsCopy, s.sessions)
|
|
s.mu.RUnlock()
|
|
|
|
data := persistedData{
|
|
Sessions: sessionsCopy,
|
|
SavedAt: time.Now(),
|
|
}
|
|
|
|
jsonData, err := json.MarshalIndent(data, "", " ")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return os.WriteFile(s.filePath, jsonData, 0600)
|
|
}
|
|
|
|
func (s *SessionStore) startPeriodicSave() {
|
|
go func() {
|
|
ticker := time.NewTicker(s.saveInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
s.Cleanup()
|
|
s.saveToFile()
|
|
case <-s.stopChan:
|
|
s.saveToFile()
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (s *SessionStore) Close() error {
|
|
close(s.stopChan)
|
|
return s.saveToFile()
|
|
}
|
|
|
|
func (s *SessionStore) Stats() (total, active int) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
now := time.Now()
|
|
total = len(s.sessions)
|
|
|
|
for _, session := range s.sessions {
|
|
if now.Before(session.ExpiresAt) {
|
|
active++
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|