separate session into its own package, clean up some docs
This commit is contained in:
parent
b8b77351d0
commit
4a73b7cc0d
@ -1,33 +1,30 @@
|
|||||||
|
// Package auth provides authentication and session management functionality.
|
||||||
|
// It includes secure session storage with in-memory caching and JSON persistence,
|
||||||
|
// user authentication against the database, and secure cookie handling.
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"dk/internal/password"
|
"dk/internal/password"
|
||||||
|
"dk/internal/session"
|
||||||
"dk/internal/users"
|
"dk/internal/users"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager is the global singleton instance
|
|
||||||
var Manager *AuthManager
|
var Manager *AuthManager
|
||||||
|
|
||||||
// AuthManager is a wrapper for the session store to add
|
|
||||||
// authentication tools over the store itself
|
|
||||||
type AuthManager struct {
|
type AuthManager struct {
|
||||||
store *SessionStore
|
store *session.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes the global auth manager (auth.Manager)
|
|
||||||
func Init(sessionsFilePath string) {
|
func Init(sessionsFilePath string) {
|
||||||
Manager = &AuthManager{
|
Manager = &AuthManager{
|
||||||
store: NewSessionStore(sessionsFilePath),
|
store: session.NewStore(sessionsFilePath),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate checks for the usernaname or email, then verifies the plain password
|
|
||||||
// against the stored hash.
|
|
||||||
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*users.User, error) {
|
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*users.User, error) {
|
||||||
var user *users.User
|
var user *users.User
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Try to find user by username first
|
|
||||||
user, err = users.GetByUsername(usernameOrEmail)
|
user, err = users.GetByUsername(usernameOrEmail)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
user, err = users.GetByEmail(usernameOrEmail)
|
user, err = users.GetByEmail(usernameOrEmail)
|
||||||
@ -47,16 +44,25 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*use
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *AuthManager) CreateSession(user *users.User) *Session {
|
func (am *AuthManager) CreateSession(user *users.User) *session.Session {
|
||||||
return am.store.Create(user.ID, user.Username, user.Email)
|
sess := session.New(user.ID, user.Username, user.Email)
|
||||||
|
am.store.Save(sess)
|
||||||
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *AuthManager) GetSession(sessionID string) (*Session, bool) {
|
func (am *AuthManager) GetSession(sessionID string) (*session.Session, bool) {
|
||||||
return am.store.Get(sessionID)
|
return am.store.Get(sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *AuthManager) UpdateSession(sessionID string) bool {
|
func (am *AuthManager) UpdateSession(sessionID string) bool {
|
||||||
return am.store.Update(sessionID)
|
sess, exists := am.store.Get(sessionID)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.Touch()
|
||||||
|
am.store.Save(sess)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *AuthManager) DeleteSession(sessionID string) {
|
func (am *AuthManager) DeleteSession(sessionID string) {
|
||||||
@ -71,124 +77,6 @@ func (am *AuthManager) Close() error {
|
|||||||
return am.store.Close()
|
return am.store.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFlash stores a flash message in the session that will be removed after retrieval
|
|
||||||
func (am *AuthManager) SetFlash(sessionID, key string, value any) bool {
|
|
||||||
session, exists := am.store.Get(sessionID)
|
|
||||||
if !exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
am.store.mu.Lock()
|
|
||||||
defer am.store.mu.Unlock()
|
|
||||||
|
|
||||||
if session.Data == nil {
|
|
||||||
session.Data = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store flash messages under a special key
|
|
||||||
flashData, ok := session.Data["_flash"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
flashData = make(map[string]any)
|
|
||||||
}
|
|
||||||
flashData[key] = value
|
|
||||||
session.Data["_flash"] = flashData
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFlash retrieves and removes a flash message from the session
|
|
||||||
func (am *AuthManager) GetFlash(sessionID, key string) (any, bool) {
|
|
||||||
session, exists := am.store.Get(sessionID)
|
|
||||||
if !exists {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
am.store.mu.Lock()
|
|
||||||
defer am.store.mu.Unlock()
|
|
||||||
|
|
||||||
if session.Data == nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
flashData, ok := session.Data["_flash"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
value, exists := flashData[key]
|
|
||||||
if exists {
|
|
||||||
delete(flashData, key)
|
|
||||||
if len(flashData) == 0 {
|
|
||||||
delete(session.Data, "_flash")
|
|
||||||
} else {
|
|
||||||
session.Data["_flash"] = flashData
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return value, exists
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllFlash retrieves and removes all flash messages from the session
|
|
||||||
func (am *AuthManager) GetAllFlash(sessionID string) map[string]any {
|
|
||||||
session, exists := am.store.Get(sessionID)
|
|
||||||
if !exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
am.store.mu.Lock()
|
|
||||||
defer am.store.mu.Unlock()
|
|
||||||
|
|
||||||
if session.Data == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
flashData, ok := session.Data["_flash"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove flash data from session
|
|
||||||
delete(session.Data, "_flash")
|
|
||||||
|
|
||||||
return flashData
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSessionData stores arbitrary data in the session
|
|
||||||
func (am *AuthManager) SetSessionData(sessionID, key string, value any) bool {
|
|
||||||
session, exists := am.store.Get(sessionID)
|
|
||||||
if !exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
am.store.mu.Lock()
|
|
||||||
defer am.store.mu.Unlock()
|
|
||||||
|
|
||||||
if session.Data == nil {
|
|
||||||
session.Data = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
session.Data[key] = value
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSessionData retrieves data from the session
|
|
||||||
func (am *AuthManager) GetSessionData(sessionID, key string) (any, bool) {
|
|
||||||
session, exists := am.store.Get(sessionID)
|
|
||||||
if !exists {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
am.store.mu.RLock()
|
|
||||||
defer am.store.mu.RUnlock()
|
|
||||||
|
|
||||||
if session.Data == nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
value, exists := session.Data[key]
|
|
||||||
return value, exists
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidCredentials = &AuthError{"invalid username/email or password"}
|
ErrInvalidCredentials = &AuthError{"invalid username/email or password"}
|
||||||
ErrSessionNotFound = &AuthError{"session not found"}
|
ErrSessionNotFound = &AuthError{"session not found"}
|
||||||
|
@ -2,18 +2,21 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"dk/internal/cookies"
|
"dk/internal/cookies"
|
||||||
|
"dk/internal/session"
|
||||||
"dk/internal/utils"
|
"dk/internal/utils"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const SessionCookieName = "dk_session"
|
||||||
|
|
||||||
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
|
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
|
||||||
cookies.SetSecureCookie(ctx, cookies.CookieOptions{
|
cookies.SetSecureCookie(ctx, cookies.CookieOptions{
|
||||||
Name: SessionCookieName,
|
Name: SessionCookieName,
|
||||||
Value: sessionID,
|
Value: sessionID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Expires: time.Now().Add(DefaultExpiration),
|
Expires: time.Now().Add(session.DefaultExpiration),
|
||||||
HTTPOnly: true,
|
HTTPOnly: true,
|
||||||
Secure: utils.IsHTTPS(ctx),
|
Secure: utils.IsHTTPS(ctx),
|
||||||
SameSite: "lax",
|
SameSite: "lax",
|
||||||
@ -26,4 +29,4 @@ func GetSessionCookie(ctx *fasthttp.RequestCtx) string {
|
|||||||
|
|
||||||
func DeleteSessionCookie(ctx *fasthttp.RequestCtx) {
|
func DeleteSessionCookie(ctx *fasthttp.RequestCtx) {
|
||||||
cookies.DeleteCookie(ctx, SessionCookieName)
|
cookies.DeleteCookie(ctx, SessionCookieName)
|
||||||
}
|
}
|
@ -1,4 +0,0 @@
|
|||||||
// Package auth provides authentication and session management functionality.
|
|
||||||
// It includes secure session storage with in-memory caching and JSON persistence,
|
|
||||||
// user authentication against the database, and secure cookie handling.
|
|
||||||
package auth
|
|
@ -2,46 +2,52 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
|
"dk/internal/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FlashMessage represents a flash message with type and content
|
|
||||||
type FlashMessage struct {
|
|
||||||
Type string `json:"type"` // "error", "success", "warning", "info"
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFlashMessage sets a flash message for the current session
|
|
||||||
func SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
|
func SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
|
||||||
sessionID := GetSessionCookie(ctx)
|
sessionID := GetSessionCookie(ctx)
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return Manager.SetFlash(sessionID, "message", FlashMessage{
|
sess, exists := Manager.GetSession(sessionID)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.SetFlash("message", session.FlashMessage{
|
||||||
Type: msgType,
|
Type: msgType,
|
||||||
Message: message,
|
Message: message,
|
||||||
})
|
})
|
||||||
|
Manager.store.Save(sess)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFlashMessage retrieves and removes the flash message from the current session
|
func GetFlashMessage(ctx router.Ctx) *session.FlashMessage {
|
||||||
func GetFlashMessage(ctx router.Ctx) *FlashMessage {
|
|
||||||
sessionID := GetSessionCookie(ctx)
|
sessionID := GetSessionCookie(ctx)
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
value, exists := Manager.GetFlash(sessionID, "message")
|
sess, exists := Manager.GetSession(sessionID)
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if msg, ok := value.(FlashMessage); ok {
|
value, exists := sess.GetFlash("message")
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
Manager.store.Save(sess)
|
||||||
|
|
||||||
|
if msg, ok := value.(session.FlashMessage); ok {
|
||||||
return &msg
|
return &msg
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle map[string]interface{} from JSON deserialization
|
|
||||||
if msgMap, ok := value.(map[string]interface{}); ok {
|
if msgMap, ok := value.(map[string]interface{}); ok {
|
||||||
msg := &FlashMessage{}
|
msg := &session.FlashMessage{}
|
||||||
if t, ok := msgMap["type"].(string); ok {
|
if t, ok := msgMap["type"].(string); ok {
|
||||||
msg.Type = t
|
msg.Type = t
|
||||||
}
|
}
|
||||||
@ -54,36 +60,45 @@ func GetFlashMessage(ctx router.Ctx) *FlashMessage {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFormData stores form data temporarily in the session (for repopulating forms after errors)
|
|
||||||
func SetFormData(ctx router.Ctx, data map[string]string) bool {
|
func SetFormData(ctx router.Ctx, data map[string]string) bool {
|
||||||
sessionID := GetSessionCookie(ctx)
|
sessionID := GetSessionCookie(ctx)
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return Manager.SetSessionData(sessionID, "form_data", data)
|
sess, exists := Manager.GetSession(sessionID)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.Set("form_data", data)
|
||||||
|
Manager.store.Save(sess)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFormData retrieves and removes form data from the session
|
|
||||||
func GetFormData(ctx router.Ctx) map[string]string {
|
func GetFormData(ctx router.Ctx) map[string]string {
|
||||||
sessionID := GetSessionCookie(ctx)
|
sessionID := GetSessionCookie(ctx)
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
value, exists := Manager.GetSessionData(sessionID, "form_data")
|
sess, exists := Manager.GetSession(sessionID)
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear form data after retrieval
|
value, exists := sess.Get("form_data")
|
||||||
Manager.SetSessionData(sessionID, "form_data", nil)
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.Delete("form_data")
|
||||||
|
Manager.store.Save(sess)
|
||||||
|
|
||||||
if formData, ok := value.(map[string]string); ok {
|
if formData, ok := value.(map[string]string); ok {
|
||||||
return formData
|
return formData
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle map[string]interface{} from JSON deserialization
|
|
||||||
if formMap, ok := value.(map[string]interface{}); ok {
|
if formMap, ok := value.(map[string]interface{}); ok {
|
||||||
result := make(map[string]string)
|
result := make(map[string]string)
|
||||||
for k, v := range formMap {
|
for k, v := range formMap {
|
||||||
|
@ -1,222 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"maps"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
SessionCookieName = "dk_session"
|
|
||||||
DefaultExpiration = 24 * time.Hour
|
|
||||||
SessionIDLength = 32
|
|
||||||
)
|
|
||||||
|
|
||||||
type Session struct {
|
|
||||||
ID string `json:"-"` // Exclude from JSON since it's stored as the map key
|
|
||||||
UserID int `json:"user_id"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
ExpiresAt time.Time `json:"expires_at"`
|
|
||||||
LastSeen time.Time `json:"last_seen"`
|
|
||||||
Data map[string]any `json:"data,omitempty"` // For storing additional session data
|
|
||||||
}
|
|
||||||
|
|
||||||
type SessionStore struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
sessions map[string]*Session
|
|
||||||
filePath string
|
|
||||||
saveInterval time.Duration
|
|
||||||
stopChan chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type persistedData struct {
|
|
||||||
Sessions map[string]*Session `json:"sessions"`
|
|
||||||
SavedAt time.Time `json:"saved_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSessionStore(filePath string) *SessionStore {
|
|
||||||
store := &SessionStore{
|
|
||||||
sessions: make(map[string]*Session),
|
|
||||||
filePath: filePath,
|
|
||||||
saveInterval: 5 * time.Minute,
|
|
||||||
stopChan: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
store.loadFromFile()
|
|
||||||
store.startPeriodicSave()
|
|
||||||
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) generateSessionID() string {
|
|
||||||
bytes := make([]byte, SessionIDLength)
|
|
||||||
rand.Read(bytes)
|
|
||||||
return hex.EncodeToString(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) Create(userID int, username, email string) *Session {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
session := &Session{
|
|
||||||
ID: s.generateSessionID(),
|
|
||||||
UserID: userID,
|
|
||||||
Username: username,
|
|
||||||
Email: email,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
ExpiresAt: time.Now().Add(DefaultExpiration),
|
|
||||||
LastSeen: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
s.sessions[session.ID] = session
|
|
||||||
return session
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) Get(sessionID string) (*Session, bool) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
session, exists := s.sessions[sessionID]
|
|
||||||
if !exists {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().After(session.ExpiresAt) {
|
|
||||||
delete(s.sessions, sessionID)
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return session, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) Update(sessionID string) bool {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
session, exists := s.sessions[sessionID]
|
|
||||||
if !exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().After(session.ExpiresAt) {
|
|
||||||
delete(s.sessions, sessionID)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
session.LastSeen = time.Now()
|
|
||||||
session.ExpiresAt = time.Now().Add(DefaultExpiration)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) Delete(sessionID string) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
delete(s.sessions, sessionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) Cleanup() {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
for id, session := range s.sessions {
|
|
||||||
if now.After(session.ExpiresAt) {
|
|
||||||
delete(s.sessions, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) loadFromFile() {
|
|
||||||
if s.filePath == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := os.ReadFile(s.filePath)
|
|
||||||
if err != nil {
|
|
||||||
return // File might not exist yet
|
|
||||||
}
|
|
||||||
|
|
||||||
var persisted persistedData
|
|
||||||
if err := json.Unmarshal(data, &persisted); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
for id, session := range persisted.Sessions {
|
|
||||||
if now.Before(session.ExpiresAt) {
|
|
||||||
s.sessions[id] = session
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) saveToFile() error {
|
|
||||||
if s.filePath == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.RLock()
|
|
||||||
sessionsCopy := make(map[string]*Session)
|
|
||||||
maps.Copy(sessionsCopy, s.sessions)
|
|
||||||
s.mu.RUnlock()
|
|
||||||
|
|
||||||
data := persistedData{
|
|
||||||
Sessions: sessionsCopy,
|
|
||||||
SavedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.MarshalIndent(data, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return os.WriteFile(s.filePath, jsonData, 0600)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) startPeriodicSave() {
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(s.saveInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
s.Cleanup()
|
|
||||||
s.saveToFile()
|
|
||||||
case <-s.stopChan:
|
|
||||||
s.saveToFile()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) Close() error {
|
|
||||||
close(s.stopChan)
|
|
||||||
return s.saveToFile()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SessionStore) Stats() (total, active int) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
total = len(s.sessions)
|
|
||||||
|
|
||||||
for _, session := range s.sessions {
|
|
||||||
if now.Before(session.ExpiresAt) {
|
|
||||||
active++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,3 +1,23 @@
|
|||||||
|
// Package csrf provides Cross-Site Request Forgery (CSRF) protection
|
||||||
|
// with session-based token storage and form helpers.
|
||||||
|
//
|
||||||
|
// # Basic Usage
|
||||||
|
//
|
||||||
|
// // Generate token and store in session
|
||||||
|
// token := csrf.GenerateToken(ctx, authManager)
|
||||||
|
//
|
||||||
|
// // In templates - generate hidden input field
|
||||||
|
// hiddenField := csrf.HiddenField(ctx, authManager)
|
||||||
|
//
|
||||||
|
// // Verify form submission
|
||||||
|
// if !csrf.ValidateToken(ctx, authManager, formToken) {
|
||||||
|
// // Handle CSRF validation failure
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// # Middleware Integration
|
||||||
|
//
|
||||||
|
// // Add CSRF middleware to protected routes
|
||||||
|
// r.Use(middleware.CSRF(authManager))
|
||||||
package csrf
|
package csrf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -9,6 +29,7 @@ import (
|
|||||||
|
|
||||||
"dk/internal/auth"
|
"dk/internal/auth"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
|
"dk/internal/session"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
@ -22,9 +43,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// GetCurrentSession retrieves the session from context (mirrors middleware function)
|
// GetCurrentSession retrieves the session from context (mirrors middleware function)
|
||||||
func GetCurrentSession(ctx router.Ctx) *auth.Session {
|
func GetCurrentSession(ctx router.Ctx) *session.Session {
|
||||||
if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok {
|
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
|
||||||
return session
|
return sess
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -97,23 +118,17 @@ func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken
|
|||||||
}
|
}
|
||||||
|
|
||||||
// StoreToken saves a CSRF token in the session
|
// StoreToken saves a CSRF token in the session
|
||||||
func StoreToken(session *auth.Session, token string) {
|
func StoreToken(sess *session.Session, token string) {
|
||||||
if session.Data == nil {
|
sess.Set(SessionKey, token)
|
||||||
session.Data = make(map[string]any)
|
|
||||||
}
|
|
||||||
session.Data[SessionKey] = token
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStoredToken retrieves the CSRF token from session
|
// GetStoredToken retrieves the CSRF token from session
|
||||||
func GetStoredToken(session *auth.Session) string {
|
func GetStoredToken(sess *session.Session) string {
|
||||||
if session.Data == nil {
|
if token, ok := sess.Get(SessionKey); ok {
|
||||||
return ""
|
if tokenStr, ok := token.(string); ok {
|
||||||
|
return tokenStr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if token, ok := session.Data[SessionKey].(string); ok {
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,14 +4,13 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dk/internal/auth"
|
"dk/internal/session"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGenerateToken(t *testing.T) {
|
func TestGenerateToken(t *testing.T) {
|
||||||
// Create a mock session
|
sess := &session.Session{
|
||||||
session := &auth.Session{
|
|
||||||
ID: "test-session",
|
ID: "test-session",
|
||||||
UserID: 1,
|
UserID: 1,
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
@ -22,27 +21,23 @@ func TestGenerateToken(t *testing.T) {
|
|||||||
Data: make(map[string]any),
|
Data: make(map[string]any),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create mock context
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.SetUserValue(SessionCtxKey, session)
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
// Generate token
|
|
||||||
token := GenerateToken(ctx, nil)
|
token := GenerateToken(ctx, nil)
|
||||||
|
|
||||||
if token == "" {
|
if token == "" {
|
||||||
t.Error("Expected non-empty token")
|
t.Error("Expected non-empty token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that token was stored in session
|
storedToken := GetStoredToken(sess)
|
||||||
storedToken := GetStoredToken(session)
|
|
||||||
if storedToken != token {
|
if storedToken != token {
|
||||||
t.Errorf("Expected stored token %s, got %s", token, storedToken)
|
t.Errorf("Expected stored token %s, got %s", token, storedToken)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateToken(t *testing.T) {
|
func TestValidateToken(t *testing.T) {
|
||||||
// Create session with token
|
sess := &session.Session{
|
||||||
session := &auth.Session{
|
|
||||||
ID: "test-session",
|
ID: "test-session",
|
||||||
UserID: 1,
|
UserID: 1,
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
@ -51,19 +46,16 @@ func TestValidateToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.SetUserValue(SessionCtxKey, session)
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
// Valid token should pass
|
|
||||||
if !ValidateToken(ctx, nil, "test-token") {
|
if !ValidateToken(ctx, nil, "test-token") {
|
||||||
t.Error("Expected valid token to pass validation")
|
t.Error("Expected valid token to pass validation")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invalid token should fail
|
|
||||||
if ValidateToken(ctx, nil, "wrong-token") {
|
if ValidateToken(ctx, nil, "wrong-token") {
|
||||||
t.Error("Expected invalid token to fail validation")
|
t.Error("Expected invalid token to fail validation")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty token should fail
|
|
||||||
if ValidateToken(ctx, nil, "") {
|
if ValidateToken(ctx, nil, "") {
|
||||||
t.Error("Expected empty token to fail validation")
|
t.Error("Expected empty token to fail validation")
|
||||||
}
|
}
|
||||||
@ -72,14 +64,13 @@ func TestValidateToken(t *testing.T) {
|
|||||||
func TestValidateTokenNoSession(t *testing.T) {
|
func TestValidateTokenNoSession(t *testing.T) {
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
// No session should fail validation
|
|
||||||
if ValidateToken(ctx, nil, "any-token") {
|
if ValidateToken(ctx, nil, "any-token") {
|
||||||
t.Error("Expected validation to fail with no session")
|
t.Error("Expected validation to fail with no session")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHiddenField(t *testing.T) {
|
func TestHiddenField(t *testing.T) {
|
||||||
session := &auth.Session{
|
sess := &session.Session{
|
||||||
ID: "test-session",
|
ID: "test-session",
|
||||||
UserID: 1,
|
UserID: 1,
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
@ -88,7 +79,7 @@ func TestHiddenField(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.SetUserValue(SessionCtxKey, session)
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
field := HiddenField(ctx, nil)
|
field := HiddenField(ctx, nil)
|
||||||
expected := `<input type="hidden" name="_csrf_token" value="test-token">`
|
expected := `<input type="hidden" name="_csrf_token" value="test-token">`
|
||||||
@ -102,13 +93,13 @@ func TestHiddenFieldNoSession(t *testing.T) {
|
|||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
field := HiddenField(ctx, nil)
|
field := HiddenField(ctx, nil)
|
||||||
if field != "" {
|
if field == "" {
|
||||||
t.Errorf("Expected empty field with no session, got %s", field)
|
t.Error("Expected non-empty field for guest user with cookie-based token")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTokenMeta(t *testing.T) {
|
func TestTokenMeta(t *testing.T) {
|
||||||
session := &auth.Session{
|
sess := &session.Session{
|
||||||
ID: "test-session",
|
ID: "test-session",
|
||||||
UserID: 1,
|
UserID: 1,
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
@ -117,7 +108,7 @@ func TestTokenMeta(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.SetUserValue(SessionCtxKey, session)
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
meta := TokenMeta(ctx, nil)
|
meta := TokenMeta(ctx, nil)
|
||||||
expected := `<meta name="csrf-token" content="test-token">`
|
expected := `<meta name="csrf-token" content="test-token">`
|
||||||
@ -128,30 +119,30 @@ func TestTokenMeta(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStoreAndGetToken(t *testing.T) {
|
func TestStoreAndGetToken(t *testing.T) {
|
||||||
session := &auth.Session{
|
sess := &session.Session{
|
||||||
Data: make(map[string]any),
|
Data: make(map[string]any),
|
||||||
}
|
}
|
||||||
|
|
||||||
token := "test-token"
|
token := "test-token"
|
||||||
StoreToken(session, token)
|
StoreToken(sess, token)
|
||||||
|
|
||||||
retrieved := GetStoredToken(session)
|
retrieved := GetStoredToken(sess)
|
||||||
if retrieved != token {
|
if retrieved != token {
|
||||||
t.Errorf("Expected %s, got %s", token, retrieved)
|
t.Errorf("Expected %s, got %s", token, retrieved)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetStoredTokenNoData(t *testing.T) {
|
func TestGetStoredTokenNoData(t *testing.T) {
|
||||||
session := &auth.Session{}
|
sess := &session.Session{}
|
||||||
|
|
||||||
token := GetStoredToken(session)
|
token := GetStoredToken(sess)
|
||||||
if token != "" {
|
if token != "" {
|
||||||
t.Errorf("Expected empty token, got %s", token)
|
t.Errorf("Expected empty token, got %s", token)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateFormToken(t *testing.T) {
|
func TestValidateFormToken(t *testing.T) {
|
||||||
session := &auth.Session{
|
sess := &session.Session{
|
||||||
ID: "test-session",
|
ID: "test-session",
|
||||||
UserID: 1,
|
UserID: 1,
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
@ -160,16 +151,14 @@ func TestValidateFormToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.SetUserValue(SessionCtxKey, session)
|
ctx.SetUserValue(SessionCtxKey, sess)
|
||||||
|
|
||||||
// Set form data
|
|
||||||
ctx.PostArgs().Set(TokenFieldName, "test-token")
|
ctx.PostArgs().Set(TokenFieldName, "test-token")
|
||||||
|
|
||||||
if !ValidateFormToken(ctx, nil) {
|
if !ValidateFormToken(ctx, nil) {
|
||||||
t.Error("Expected form token validation to pass")
|
t.Error("Expected form token validation to pass")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test with wrong token
|
|
||||||
ctx.PostArgs().Set(TokenFieldName, "wrong-token")
|
ctx.PostArgs().Set(TokenFieldName, "wrong-token")
|
||||||
|
|
||||||
if ValidateFormToken(ctx, nil) {
|
if ValidateFormToken(ctx, nil) {
|
||||||
|
@ -1,29 +0,0 @@
|
|||||||
// Package csrf provides Cross-Site Request Forgery (CSRF) protection
|
|
||||||
// with session-based token storage and form helpers.
|
|
||||||
//
|
|
||||||
// # Basic Usage
|
|
||||||
//
|
|
||||||
// // Generate token and store in session
|
|
||||||
// token := csrf.GenerateToken(ctx, authManager)
|
|
||||||
//
|
|
||||||
// // In templates - generate hidden input field
|
|
||||||
// hiddenField := csrf.HiddenField(ctx, authManager)
|
|
||||||
//
|
|
||||||
// // Verify form submission
|
|
||||||
// if !csrf.ValidateToken(ctx, authManager, formToken) {
|
|
||||||
// // Handle CSRF validation failure
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// # Middleware Integration
|
|
||||||
//
|
|
||||||
// // Add CSRF middleware to protected routes
|
|
||||||
// r.Use(middleware.CSRF(authManager))
|
|
||||||
//
|
|
||||||
// # Security Features
|
|
||||||
//
|
|
||||||
// - Cryptographically secure token generation
|
|
||||||
// - Session-based token storage and validation
|
|
||||||
// - Automatic token rotation on successful validation
|
|
||||||
// - Protection against timing attacks with constant-time comparison
|
|
||||||
// - Integration with existing authentication system
|
|
||||||
package csrf
|
|
@ -3,30 +3,26 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"dk/internal/auth"
|
"dk/internal/auth"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
|
"dk/internal/session"
|
||||||
"dk/internal/users"
|
"dk/internal/users"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Auth creates an authentication middleware
|
|
||||||
func Auth(authManager *auth.AuthManager) router.Middleware {
|
func Auth(authManager *auth.AuthManager) router.Middleware {
|
||||||
return func(next router.Handler) router.Handler {
|
return func(next router.Handler) router.Handler {
|
||||||
return func(ctx router.Ctx, params []string) {
|
return func(ctx router.Ctx, params []string) {
|
||||||
sessionID := auth.GetSessionCookie(ctx)
|
sessionID := auth.GetSessionCookie(ctx)
|
||||||
|
|
||||||
if sessionID != "" {
|
if sessionID != "" {
|
||||||
if session, exists := authManager.GetSession(sessionID); exists {
|
if sess, exists := authManager.GetSession(sessionID); exists {
|
||||||
// Update session activity
|
|
||||||
authManager.UpdateSession(sessionID)
|
authManager.UpdateSession(sessionID)
|
||||||
|
|
||||||
// Get the full user object
|
user, err := users.Find(sess.UserID)
|
||||||
user, err := users.Find(session.UserID)
|
|
||||||
if err == nil && user != nil {
|
if err == nil && user != nil {
|
||||||
// Store session and user info in context
|
ctx.SetUserValue("session", sess)
|
||||||
ctx.SetUserValue("session", session)
|
|
||||||
ctx.SetUserValue("user", user)
|
ctx.SetUserValue("user", user)
|
||||||
|
|
||||||
// Refresh the cookie
|
|
||||||
auth.SetSessionCookie(ctx, sessionID)
|
auth.SetSessionCookie(ctx, sessionID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -37,7 +33,6 @@ func Auth(authManager *auth.AuthManager) router.Middleware {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequireAuth enforces authentication - redirect defaults to "/login"
|
|
||||||
func RequireAuth(paths ...string) router.Middleware {
|
func RequireAuth(paths ...string) router.Middleware {
|
||||||
redirect := "/login"
|
redirect := "/login"
|
||||||
if len(paths) > 0 && paths[0] != "" {
|
if len(paths) > 0 && paths[0] != "" {
|
||||||
@ -56,7 +51,6 @@ func RequireAuth(paths ...string) router.Middleware {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequireGuest enforces no authentication - redirect defaults to "/"
|
|
||||||
func RequireGuest(paths ...string) router.Middleware {
|
func RequireGuest(paths ...string) router.Middleware {
|
||||||
redirect := "/"
|
redirect := "/"
|
||||||
if len(paths) > 0 && paths[0] != "" {
|
if len(paths) > 0 && paths[0] != "" {
|
||||||
@ -74,13 +68,11 @@ func RequireGuest(paths ...string) router.Middleware {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAuthenticated checks if the current request has a valid session
|
|
||||||
func IsAuthenticated(ctx router.Ctx) bool {
|
func IsAuthenticated(ctx router.Ctx) bool {
|
||||||
_, exists := ctx.UserValue("user").(*users.User)
|
_, exists := ctx.UserValue("user").(*users.User)
|
||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCurrentUser returns the current authenticated user, or nil if not authenticated
|
|
||||||
func GetCurrentUser(ctx router.Ctx) *users.User {
|
func GetCurrentUser(ctx router.Ctx) *users.User {
|
||||||
if user, ok := ctx.UserValue("user").(*users.User); ok {
|
if user, ok := ctx.UserValue("user").(*users.User); ok {
|
||||||
return user
|
return user
|
||||||
@ -88,25 +80,21 @@ func GetCurrentUser(ctx router.Ctx) *users.User {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCurrentSession returns the current session, or nil if not authenticated
|
func GetCurrentSession(ctx router.Ctx) *session.Session {
|
||||||
func GetCurrentSession(ctx router.Ctx) *auth.Session {
|
if sess, ok := ctx.UserValue("session").(*session.Session); ok {
|
||||||
if session, ok := ctx.UserValue("session").(*auth.Session); ok {
|
return sess
|
||||||
return session
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login creates a session and sets the cookie
|
|
||||||
func Login(ctx router.Ctx, authManager *auth.AuthManager, user *users.User) {
|
func Login(ctx router.Ctx, authManager *auth.AuthManager, user *users.User) {
|
||||||
session := authManager.CreateSession(user)
|
sess := authManager.CreateSession(user)
|
||||||
auth.SetSessionCookie(ctx, session.ID)
|
auth.SetSessionCookie(ctx, sess.ID)
|
||||||
|
|
||||||
// Set in context for immediate use
|
ctx.SetUserValue("session", sess)
|
||||||
ctx.SetUserValue("session", session)
|
|
||||||
ctx.SetUserValue("user", user)
|
ctx.SetUserValue("user", user)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logout destroys the session and clears the cookie
|
|
||||||
func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
|
func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
|
||||||
sessionID := auth.GetSessionCookie(ctx)
|
sessionID := auth.GetSessionCookie(ctx)
|
||||||
if sessionID != "" {
|
if sessionID != "" {
|
||||||
@ -115,7 +103,6 @@ func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
|
|||||||
|
|
||||||
auth.DeleteSessionCookie(ctx)
|
auth.DeleteSessionCookie(ctx)
|
||||||
|
|
||||||
// Clear from context
|
|
||||||
ctx.SetUserValue("session", nil)
|
ctx.SetUserValue("session", nil)
|
||||||
ctx.SetUserValue("user", nil)
|
ctx.SetUserValue("user", nil)
|
||||||
}
|
}
|
56
internal/session/flash.go
Normal file
56
internal/session/flash.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
type FlashMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) SetFlash(key string, value any) {
|
||||||
|
if s.Data == nil {
|
||||||
|
s.Data = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
flashData, ok := s.Data["_flash"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
flashData = make(map[string]any)
|
||||||
|
}
|
||||||
|
flashData[key] = value
|
||||||
|
s.Data["_flash"] = flashData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) GetFlash(key string) (any, bool) {
|
||||||
|
if s.Data == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
flashData, ok := s.Data["_flash"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
value, exists := flashData[key]
|
||||||
|
if exists {
|
||||||
|
delete(flashData, key)
|
||||||
|
if len(flashData) == 0 {
|
||||||
|
delete(s.Data, "_flash")
|
||||||
|
} else {
|
||||||
|
s.Data["_flash"] = flashData
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) GetAllFlash() map[string]any {
|
||||||
|
if s.Data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
flashData, ok := s.Data["_flash"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(s.Data, "_flash")
|
||||||
|
return flashData
|
||||||
|
}
|
74
internal/session/session.go
Normal file
74
internal/session/session.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
// Package session provides session management functionality.
|
||||||
|
// It includes session storage, flash messages, and data persistence.
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultExpiration = 24 * time.Hour
|
||||||
|
IDLength = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ID string `json:"-"`
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
Data map[string]any `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(userID int, username, email string) *Session {
|
||||||
|
return &Session{
|
||||||
|
ID: generateID(),
|
||||||
|
UserID: userID,
|
||||||
|
Username: username,
|
||||||
|
Email: email,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(DefaultExpiration),
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
Data: make(map[string]any),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) IsExpired() bool {
|
||||||
|
return time.Now().After(s.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Touch() {
|
||||||
|
s.LastSeen = time.Now()
|
||||||
|
s.ExpiresAt = time.Now().Add(DefaultExpiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Set(key string, value any) {
|
||||||
|
if s.Data == nil {
|
||||||
|
s.Data = make(map[string]any)
|
||||||
|
}
|
||||||
|
s.Data[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Get(key string) (any, bool) {
|
||||||
|
if s.Data == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
value, exists := s.Data[key]
|
||||||
|
return value, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Delete(key string) {
|
||||||
|
if s.Data != nil {
|
||||||
|
delete(s.Data, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateID() string {
|
||||||
|
bytes := make([]byte, IDLength)
|
||||||
|
rand.Read(bytes)
|
||||||
|
return hex.EncodeToString(bytes)
|
||||||
|
}
|
161
internal/session/store.go
Normal file
161
internal/session/store.go
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Store struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
sessions map[string]*Session
|
||||||
|
filePath string
|
||||||
|
saveInterval time.Duration
|
||||||
|
stopChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type persistedData struct {
|
||||||
|
Sessions map[string]*Session `json:"sessions"`
|
||||||
|
SavedAt time.Time `json:"saved_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStore(filePath string) *Store {
|
||||||
|
store := &Store{
|
||||||
|
sessions: make(map[string]*Session),
|
||||||
|
filePath: filePath,
|
||||||
|
saveInterval: 5 * time.Minute,
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
store.loadFromFile()
|
||||||
|
store.startPeriodicSave()
|
||||||
|
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Save(session *Session) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.sessions[session.ID] = session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Get(sessionID string) (*Session, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
session, exists := s.sessions[sessionID]
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.IsExpired() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Delete(sessionID string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.sessions, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Cleanup() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for id, session := range s.sessions {
|
||||||
|
if session.IsExpired() {
|
||||||
|
delete(s.sessions, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Stats() (total, active int) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
total = len(s.sessions)
|
||||||
|
for _, session := range s.sessions {
|
||||||
|
if !session.IsExpired() {
|
||||||
|
active++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) loadFromFile() {
|
||||||
|
if s.filePath == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(s.filePath)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var persisted persistedData
|
||||||
|
if err := json.Unmarshal(data, &persisted); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for id, session := range persisted.Sessions {
|
||||||
|
if !session.IsExpired() {
|
||||||
|
session.ID = id
|
||||||
|
s.sessions[id] = session
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) saveToFile() error {
|
||||||
|
if s.filePath == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
sessionsCopy := make(map[string]*Session, len(s.sessions))
|
||||||
|
maps.Copy(sessionsCopy, s.sessions)
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
data := persistedData{
|
||||||
|
Sessions: sessionsCopy,
|
||||||
|
SavedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.MarshalIndent(data, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.WriteFile(s.filePath, jsonData, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) startPeriodicSave() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(s.saveInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
s.Cleanup()
|
||||||
|
s.saveToFile()
|
||||||
|
case <-s.stopChan:
|
||||||
|
s.saveToFile()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Close() error {
|
||||||
|
close(s.stopChan)
|
||||||
|
return s.saveToFile()
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user