separate session into its own package, clean up some docs
This commit is contained in:
parent
b8b77351d0
commit
4a73b7cc0d
@ -1,33 +1,30 @@
|
||||
// Package auth provides authentication and session management functionality.
|
||||
// It includes secure session storage with in-memory caching and JSON persistence,
|
||||
// user authentication against the database, and secure cookie handling.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"dk/internal/password"
|
||||
"dk/internal/session"
|
||||
"dk/internal/users"
|
||||
)
|
||||
|
||||
// Manager is the global singleton instance
|
||||
var Manager *AuthManager
|
||||
|
||||
// AuthManager is a wrapper for the session store to add
|
||||
// authentication tools over the store itself
|
||||
type AuthManager struct {
|
||||
store *SessionStore
|
||||
store *session.Store
|
||||
}
|
||||
|
||||
// Init initializes the global auth manager (auth.Manager)
|
||||
func Init(sessionsFilePath string) {
|
||||
Manager = &AuthManager{
|
||||
store: NewSessionStore(sessionsFilePath),
|
||||
store: session.NewStore(sessionsFilePath),
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate checks for the usernaname or email, then verifies the plain password
|
||||
// against the stored hash.
|
||||
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*users.User, error) {
|
||||
var user *users.User
|
||||
var err error
|
||||
|
||||
// Try to find user by username first
|
||||
user, err = users.GetByUsername(usernameOrEmail)
|
||||
if err != nil {
|
||||
user, err = users.GetByEmail(usernameOrEmail)
|
||||
@ -47,16 +44,25 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*use
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (am *AuthManager) CreateSession(user *users.User) *Session {
|
||||
return am.store.Create(user.ID, user.Username, user.Email)
|
||||
func (am *AuthManager) CreateSession(user *users.User) *session.Session {
|
||||
sess := session.New(user.ID, user.Username, user.Email)
|
||||
am.store.Save(sess)
|
||||
return sess
|
||||
}
|
||||
|
||||
func (am *AuthManager) GetSession(sessionID string) (*Session, bool) {
|
||||
func (am *AuthManager) GetSession(sessionID string) (*session.Session, bool) {
|
||||
return am.store.Get(sessionID)
|
||||
}
|
||||
|
||||
func (am *AuthManager) UpdateSession(sessionID string) bool {
|
||||
return am.store.Update(sessionID)
|
||||
sess, exists := am.store.Get(sessionID)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
sess.Touch()
|
||||
am.store.Save(sess)
|
||||
return true
|
||||
}
|
||||
|
||||
func (am *AuthManager) DeleteSession(sessionID string) {
|
||||
@ -71,124 +77,6 @@ func (am *AuthManager) Close() error {
|
||||
return am.store.Close()
|
||||
}
|
||||
|
||||
// SetFlash stores a flash message in the session that will be removed after retrieval
|
||||
func (am *AuthManager) SetFlash(sessionID, key string, value any) bool {
|
||||
session, exists := am.store.Get(sessionID)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
am.store.mu.Lock()
|
||||
defer am.store.mu.Unlock()
|
||||
|
||||
if session.Data == nil {
|
||||
session.Data = make(map[string]any)
|
||||
}
|
||||
|
||||
// Store flash messages under a special key
|
||||
flashData, ok := session.Data["_flash"].(map[string]any)
|
||||
if !ok {
|
||||
flashData = make(map[string]any)
|
||||
}
|
||||
flashData[key] = value
|
||||
session.Data["_flash"] = flashData
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetFlash retrieves and removes a flash message from the session
|
||||
func (am *AuthManager) GetFlash(sessionID, key string) (any, bool) {
|
||||
session, exists := am.store.Get(sessionID)
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
am.store.mu.Lock()
|
||||
defer am.store.mu.Unlock()
|
||||
|
||||
if session.Data == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
flashData, ok := session.Data["_flash"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
value, exists := flashData[key]
|
||||
if exists {
|
||||
delete(flashData, key)
|
||||
if len(flashData) == 0 {
|
||||
delete(session.Data, "_flash")
|
||||
} else {
|
||||
session.Data["_flash"] = flashData
|
||||
}
|
||||
}
|
||||
|
||||
return value, exists
|
||||
}
|
||||
|
||||
// GetAllFlash retrieves and removes all flash messages from the session
|
||||
func (am *AuthManager) GetAllFlash(sessionID string) map[string]any {
|
||||
session, exists := am.store.Get(sessionID)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
am.store.mu.Lock()
|
||||
defer am.store.mu.Unlock()
|
||||
|
||||
if session.Data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
flashData, ok := session.Data["_flash"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove flash data from session
|
||||
delete(session.Data, "_flash")
|
||||
|
||||
return flashData
|
||||
}
|
||||
|
||||
// SetSessionData stores arbitrary data in the session
|
||||
func (am *AuthManager) SetSessionData(sessionID, key string, value any) bool {
|
||||
session, exists := am.store.Get(sessionID)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
am.store.mu.Lock()
|
||||
defer am.store.mu.Unlock()
|
||||
|
||||
if session.Data == nil {
|
||||
session.Data = make(map[string]any)
|
||||
}
|
||||
|
||||
session.Data[key] = value
|
||||
return true
|
||||
}
|
||||
|
||||
// GetSessionData retrieves data from the session
|
||||
func (am *AuthManager) GetSessionData(sessionID, key string) (any, bool) {
|
||||
session, exists := am.store.Get(sessionID)
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
am.store.mu.RLock()
|
||||
defer am.store.mu.RUnlock()
|
||||
|
||||
if session.Data == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
value, exists := session.Data[key]
|
||||
return value, exists
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = &AuthError{"invalid username/email or password"}
|
||||
ErrSessionNotFound = &AuthError{"session not found"}
|
||||
|
@ -2,18 +2,21 @@ package auth
|
||||
|
||||
import (
|
||||
"dk/internal/cookies"
|
||||
"dk/internal/session"
|
||||
"dk/internal/utils"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const SessionCookieName = "dk_session"
|
||||
|
||||
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
|
||||
cookies.SetSecureCookie(ctx, cookies.CookieOptions{
|
||||
Name: SessionCookieName,
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
Expires: time.Now().Add(DefaultExpiration),
|
||||
Expires: time.Now().Add(session.DefaultExpiration),
|
||||
HTTPOnly: true,
|
||||
Secure: utils.IsHTTPS(ctx),
|
||||
SameSite: "lax",
|
||||
|
@ -1,4 +0,0 @@
|
||||
// Package auth provides authentication and session management functionality.
|
||||
// It includes secure session storage with in-memory caching and JSON persistence,
|
||||
// user authentication against the database, and secure cookie handling.
|
||||
package auth
|
@ -2,46 +2,52 @@ package auth
|
||||
|
||||
import (
|
||||
"dk/internal/router"
|
||||
"dk/internal/session"
|
||||
)
|
||||
|
||||
// FlashMessage represents a flash message with type and content
|
||||
type FlashMessage struct {
|
||||
Type string `json:"type"` // "error", "success", "warning", "info"
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// SetFlashMessage sets a flash message for the current session
|
||||
func SetFlashMessage(ctx router.Ctx, msgType, message string) bool {
|
||||
sessionID := GetSessionCookie(ctx)
|
||||
if sessionID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return Manager.SetFlash(sessionID, "message", FlashMessage{
|
||||
sess, exists := Manager.GetSession(sessionID)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
sess.SetFlash("message", session.FlashMessage{
|
||||
Type: msgType,
|
||||
Message: message,
|
||||
})
|
||||
Manager.store.Save(sess)
|
||||
return true
|
||||
}
|
||||
|
||||
// GetFlashMessage retrieves and removes the flash message from the current session
|
||||
func GetFlashMessage(ctx router.Ctx) *FlashMessage {
|
||||
func GetFlashMessage(ctx router.Ctx) *session.FlashMessage {
|
||||
sessionID := GetSessionCookie(ctx)
|
||||
if sessionID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
value, exists := Manager.GetFlash(sessionID, "message")
|
||||
sess, exists := Manager.GetSession(sessionID)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if msg, ok := value.(FlashMessage); ok {
|
||||
value, exists := sess.GetFlash("message")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
Manager.store.Save(sess)
|
||||
|
||||
if msg, ok := value.(session.FlashMessage); ok {
|
||||
return &msg
|
||||
}
|
||||
|
||||
// Handle map[string]interface{} from JSON deserialization
|
||||
if msgMap, ok := value.(map[string]interface{}); ok {
|
||||
msg := &FlashMessage{}
|
||||
msg := &session.FlashMessage{}
|
||||
if t, ok := msgMap["type"].(string); ok {
|
||||
msg.Type = t
|
||||
}
|
||||
@ -54,36 +60,45 @@ func GetFlashMessage(ctx router.Ctx) *FlashMessage {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetFormData stores form data temporarily in the session (for repopulating forms after errors)
|
||||
func SetFormData(ctx router.Ctx, data map[string]string) bool {
|
||||
sessionID := GetSessionCookie(ctx)
|
||||
if sessionID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
return Manager.SetSessionData(sessionID, "form_data", data)
|
||||
sess, exists := Manager.GetSession(sessionID)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
sess.Set("form_data", data)
|
||||
Manager.store.Save(sess)
|
||||
return true
|
||||
}
|
||||
|
||||
// GetFormData retrieves and removes form data from the session
|
||||
func GetFormData(ctx router.Ctx) map[string]string {
|
||||
sessionID := GetSessionCookie(ctx)
|
||||
if sessionID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
value, exists := Manager.GetSessionData(sessionID, "form_data")
|
||||
sess, exists := Manager.GetSession(sessionID)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear form data after retrieval
|
||||
Manager.SetSessionData(sessionID, "form_data", nil)
|
||||
value, exists := sess.Get("form_data")
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
sess.Delete("form_data")
|
||||
Manager.store.Save(sess)
|
||||
|
||||
if formData, ok := value.(map[string]string); ok {
|
||||
return formData
|
||||
}
|
||||
|
||||
// Handle map[string]interface{} from JSON deserialization
|
||||
if formMap, ok := value.(map[string]interface{}); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range formMap {
|
||||
|
@ -1,222 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"maps"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
SessionCookieName = "dk_session"
|
||||
DefaultExpiration = 24 * time.Hour
|
||||
SessionIDLength = 32
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"-"` // Exclude from JSON since it's stored as the map key
|
||||
UserID int `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
Data map[string]any `json:"data,omitempty"` // For storing additional session data
|
||||
}
|
||||
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*Session
|
||||
filePath string
|
||||
saveInterval time.Duration
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
type persistedData struct {
|
||||
Sessions map[string]*Session `json:"sessions"`
|
||||
SavedAt time.Time `json:"saved_at"`
|
||||
}
|
||||
|
||||
func NewSessionStore(filePath string) *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*Session),
|
||||
filePath: filePath,
|
||||
saveInterval: 5 * time.Minute,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
store.loadFromFile()
|
||||
store.startPeriodicSave()
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *SessionStore) generateSessionID() string {
|
||||
bytes := make([]byte, SessionIDLength)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
func (s *SessionStore) Create(userID int, username, email string) *Session {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session := &Session{
|
||||
ID: s.generateSessionID(),
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(DefaultExpiration),
|
||||
LastSeen: time.Now(),
|
||||
}
|
||||
|
||||
s.sessions[session.ID] = session
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(sessionID string) (*Session, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
delete(s.sessions, sessionID)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return session, true
|
||||
}
|
||||
|
||||
func (s *SessionStore) Update(sessionID string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
delete(s.sessions, sessionID)
|
||||
return false
|
||||
}
|
||||
|
||||
session.LastSeen = time.Now()
|
||||
session.ExpiresAt = time.Now().Add(DefaultExpiration)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
func (s *SessionStore) Cleanup() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range s.sessions {
|
||||
if now.After(session.ExpiresAt) {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionStore) loadFromFile() {
|
||||
if s.filePath == "" {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(s.filePath)
|
||||
if err != nil {
|
||||
return // File might not exist yet
|
||||
}
|
||||
|
||||
var persisted persistedData
|
||||
if err := json.Unmarshal(data, &persisted); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range persisted.Sessions {
|
||||
if now.Before(session.ExpiresAt) {
|
||||
s.sessions[id] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionStore) saveToFile() error {
|
||||
if s.filePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
sessionsCopy := make(map[string]*Session)
|
||||
maps.Copy(sessionsCopy, s.sessions)
|
||||
s.mu.RUnlock()
|
||||
|
||||
data := persistedData{
|
||||
Sessions: sessionsCopy,
|
||||
SavedAt: time.Now(),
|
||||
}
|
||||
|
||||
jsonData, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(s.filePath, jsonData, 0600)
|
||||
}
|
||||
|
||||
func (s *SessionStore) startPeriodicSave() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.saveInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.Cleanup()
|
||||
s.saveToFile()
|
||||
case <-s.stopChan:
|
||||
s.saveToFile()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *SessionStore) Close() error {
|
||||
close(s.stopChan)
|
||||
return s.saveToFile()
|
||||
}
|
||||
|
||||
func (s *SessionStore) Stats() (total, active int) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
total = len(s.sessions)
|
||||
|
||||
for _, session := range s.sessions {
|
||||
if now.Before(session.ExpiresAt) {
|
||||
active++
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
@ -1,3 +1,23 @@
|
||||
// Package csrf provides Cross-Site Request Forgery (CSRF) protection
|
||||
// with session-based token storage and form helpers.
|
||||
//
|
||||
// # Basic Usage
|
||||
//
|
||||
// // Generate token and store in session
|
||||
// token := csrf.GenerateToken(ctx, authManager)
|
||||
//
|
||||
// // In templates - generate hidden input field
|
||||
// hiddenField := csrf.HiddenField(ctx, authManager)
|
||||
//
|
||||
// // Verify form submission
|
||||
// if !csrf.ValidateToken(ctx, authManager, formToken) {
|
||||
// // Handle CSRF validation failure
|
||||
// }
|
||||
//
|
||||
// # Middleware Integration
|
||||
//
|
||||
// // Add CSRF middleware to protected routes
|
||||
// r.Use(middleware.CSRF(authManager))
|
||||
package csrf
|
||||
|
||||
import (
|
||||
@ -9,6 +29,7 @@ import (
|
||||
|
||||
"dk/internal/auth"
|
||||
"dk/internal/router"
|
||||
"dk/internal/session"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -22,9 +43,9 @@ const (
|
||||
)
|
||||
|
||||
// GetCurrentSession retrieves the session from context (mirrors middleware function)
|
||||
func GetCurrentSession(ctx router.Ctx) *auth.Session {
|
||||
if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok {
|
||||
return session
|
||||
func GetCurrentSession(ctx router.Ctx) *session.Session {
|
||||
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
|
||||
return sess
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -97,23 +118,17 @@ func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken
|
||||
}
|
||||
|
||||
// StoreToken saves a CSRF token in the session
|
||||
func StoreToken(session *auth.Session, token string) {
|
||||
if session.Data == nil {
|
||||
session.Data = make(map[string]any)
|
||||
}
|
||||
session.Data[SessionKey] = token
|
||||
func StoreToken(sess *session.Session, token string) {
|
||||
sess.Set(SessionKey, token)
|
||||
}
|
||||
|
||||
// GetStoredToken retrieves the CSRF token from session
|
||||
func GetStoredToken(session *auth.Session) string {
|
||||
if session.Data == nil {
|
||||
return ""
|
||||
func GetStoredToken(sess *session.Session) string {
|
||||
if token, ok := sess.Get(SessionKey); ok {
|
||||
if tokenStr, ok := token.(string); ok {
|
||||
return tokenStr
|
||||
}
|
||||
|
||||
if token, ok := session.Data[SessionKey].(string); ok {
|
||||
return token
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
|
@ -4,14 +4,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"dk/internal/auth"
|
||||
"dk/internal/session"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
// Create a mock session
|
||||
session := &auth.Session{
|
||||
sess := &session.Session{
|
||||
ID: "test-session",
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
@ -22,27 +21,23 @@ func TestGenerateToken(t *testing.T) {
|
||||
Data: make(map[string]any),
|
||||
}
|
||||
|
||||
// Create mock context
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(SessionCtxKey, session)
|
||||
ctx.SetUserValue(SessionCtxKey, sess)
|
||||
|
||||
// Generate token
|
||||
token := GenerateToken(ctx, nil)
|
||||
|
||||
if token == "" {
|
||||
t.Error("Expected non-empty token")
|
||||
}
|
||||
|
||||
// Check that token was stored in session
|
||||
storedToken := GetStoredToken(session)
|
||||
storedToken := GetStoredToken(sess)
|
||||
if storedToken != token {
|
||||
t.Errorf("Expected stored token %s, got %s", token, storedToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateToken(t *testing.T) {
|
||||
// Create session with token
|
||||
session := &auth.Session{
|
||||
sess := &session.Session{
|
||||
ID: "test-session",
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
@ -51,19 +46,16 @@ func TestValidateToken(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(SessionCtxKey, session)
|
||||
ctx.SetUserValue(SessionCtxKey, sess)
|
||||
|
||||
// Valid token should pass
|
||||
if !ValidateToken(ctx, nil, "test-token") {
|
||||
t.Error("Expected valid token to pass validation")
|
||||
}
|
||||
|
||||
// Invalid token should fail
|
||||
if ValidateToken(ctx, nil, "wrong-token") {
|
||||
t.Error("Expected invalid token to fail validation")
|
||||
}
|
||||
|
||||
// Empty token should fail
|
||||
if ValidateToken(ctx, nil, "") {
|
||||
t.Error("Expected empty token to fail validation")
|
||||
}
|
||||
@ -72,14 +64,13 @@ func TestValidateToken(t *testing.T) {
|
||||
func TestValidateTokenNoSession(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// No session should fail validation
|
||||
if ValidateToken(ctx, nil, "any-token") {
|
||||
t.Error("Expected validation to fail with no session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHiddenField(t *testing.T) {
|
||||
session := &auth.Session{
|
||||
sess := &session.Session{
|
||||
ID: "test-session",
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
@ -88,7 +79,7 @@ func TestHiddenField(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(SessionCtxKey, session)
|
||||
ctx.SetUserValue(SessionCtxKey, sess)
|
||||
|
||||
field := HiddenField(ctx, nil)
|
||||
expected := `<input type="hidden" name="_csrf_token" value="test-token">`
|
||||
@ -102,13 +93,13 @@ func TestHiddenFieldNoSession(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
field := HiddenField(ctx, nil)
|
||||
if field != "" {
|
||||
t.Errorf("Expected empty field with no session, got %s", field)
|
||||
if field == "" {
|
||||
t.Error("Expected non-empty field for guest user with cookie-based token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenMeta(t *testing.T) {
|
||||
session := &auth.Session{
|
||||
sess := &session.Session{
|
||||
ID: "test-session",
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
@ -117,7 +108,7 @@ func TestTokenMeta(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(SessionCtxKey, session)
|
||||
ctx.SetUserValue(SessionCtxKey, sess)
|
||||
|
||||
meta := TokenMeta(ctx, nil)
|
||||
expected := `<meta name="csrf-token" content="test-token">`
|
||||
@ -128,30 +119,30 @@ func TestTokenMeta(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStoreAndGetToken(t *testing.T) {
|
||||
session := &auth.Session{
|
||||
sess := &session.Session{
|
||||
Data: make(map[string]any),
|
||||
}
|
||||
|
||||
token := "test-token"
|
||||
StoreToken(session, token)
|
||||
StoreToken(sess, token)
|
||||
|
||||
retrieved := GetStoredToken(session)
|
||||
retrieved := GetStoredToken(sess)
|
||||
if retrieved != token {
|
||||
t.Errorf("Expected %s, got %s", token, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStoredTokenNoData(t *testing.T) {
|
||||
session := &auth.Session{}
|
||||
sess := &session.Session{}
|
||||
|
||||
token := GetStoredToken(session)
|
||||
token := GetStoredToken(sess)
|
||||
if token != "" {
|
||||
t.Errorf("Expected empty token, got %s", token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFormToken(t *testing.T) {
|
||||
session := &auth.Session{
|
||||
sess := &session.Session{
|
||||
ID: "test-session",
|
||||
UserID: 1,
|
||||
Username: "testuser",
|
||||
@ -160,16 +151,14 @@ func TestValidateFormToken(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(SessionCtxKey, session)
|
||||
ctx.SetUserValue(SessionCtxKey, sess)
|
||||
|
||||
// Set form data
|
||||
ctx.PostArgs().Set(TokenFieldName, "test-token")
|
||||
|
||||
if !ValidateFormToken(ctx, nil) {
|
||||
t.Error("Expected form token validation to pass")
|
||||
}
|
||||
|
||||
// Test with wrong token
|
||||
ctx.PostArgs().Set(TokenFieldName, "wrong-token")
|
||||
|
||||
if ValidateFormToken(ctx, nil) {
|
||||
|
@ -1,29 +0,0 @@
|
||||
// Package csrf provides Cross-Site Request Forgery (CSRF) protection
|
||||
// with session-based token storage and form helpers.
|
||||
//
|
||||
// # Basic Usage
|
||||
//
|
||||
// // Generate token and store in session
|
||||
// token := csrf.GenerateToken(ctx, authManager)
|
||||
//
|
||||
// // In templates - generate hidden input field
|
||||
// hiddenField := csrf.HiddenField(ctx, authManager)
|
||||
//
|
||||
// // Verify form submission
|
||||
// if !csrf.ValidateToken(ctx, authManager, formToken) {
|
||||
// // Handle CSRF validation failure
|
||||
// }
|
||||
//
|
||||
// # Middleware Integration
|
||||
//
|
||||
// // Add CSRF middleware to protected routes
|
||||
// r.Use(middleware.CSRF(authManager))
|
||||
//
|
||||
// # Security Features
|
||||
//
|
||||
// - Cryptographically secure token generation
|
||||
// - Session-based token storage and validation
|
||||
// - Automatic token rotation on successful validation
|
||||
// - Protection against timing attacks with constant-time comparison
|
||||
// - Integration with existing authentication system
|
||||
package csrf
|
@ -3,30 +3,26 @@ package middleware
|
||||
import (
|
||||
"dk/internal/auth"
|
||||
"dk/internal/router"
|
||||
"dk/internal/session"
|
||||
"dk/internal/users"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Auth creates an authentication middleware
|
||||
func Auth(authManager *auth.AuthManager) router.Middleware {
|
||||
return func(next router.Handler) router.Handler {
|
||||
return func(ctx router.Ctx, params []string) {
|
||||
sessionID := auth.GetSessionCookie(ctx)
|
||||
|
||||
if sessionID != "" {
|
||||
if session, exists := authManager.GetSession(sessionID); exists {
|
||||
// Update session activity
|
||||
if sess, exists := authManager.GetSession(sessionID); exists {
|
||||
authManager.UpdateSession(sessionID)
|
||||
|
||||
// Get the full user object
|
||||
user, err := users.Find(session.UserID)
|
||||
user, err := users.Find(sess.UserID)
|
||||
if err == nil && user != nil {
|
||||
// Store session and user info in context
|
||||
ctx.SetUserValue("session", session)
|
||||
ctx.SetUserValue("session", sess)
|
||||
ctx.SetUserValue("user", user)
|
||||
|
||||
// Refresh the cookie
|
||||
auth.SetSessionCookie(ctx, sessionID)
|
||||
}
|
||||
}
|
||||
@ -37,7 +33,6 @@ func Auth(authManager *auth.AuthManager) router.Middleware {
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth enforces authentication - redirect defaults to "/login"
|
||||
func RequireAuth(paths ...string) router.Middleware {
|
||||
redirect := "/login"
|
||||
if len(paths) > 0 && paths[0] != "" {
|
||||
@ -56,7 +51,6 @@ func RequireAuth(paths ...string) router.Middleware {
|
||||
}
|
||||
}
|
||||
|
||||
// RequireGuest enforces no authentication - redirect defaults to "/"
|
||||
func RequireGuest(paths ...string) router.Middleware {
|
||||
redirect := "/"
|
||||
if len(paths) > 0 && paths[0] != "" {
|
||||
@ -74,13 +68,11 @@ func RequireGuest(paths ...string) router.Middleware {
|
||||
}
|
||||
}
|
||||
|
||||
// IsAuthenticated checks if the current request has a valid session
|
||||
func IsAuthenticated(ctx router.Ctx) bool {
|
||||
_, exists := ctx.UserValue("user").(*users.User)
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetCurrentUser returns the current authenticated user, or nil if not authenticated
|
||||
func GetCurrentUser(ctx router.Ctx) *users.User {
|
||||
if user, ok := ctx.UserValue("user").(*users.User); ok {
|
||||
return user
|
||||
@ -88,25 +80,21 @@ func GetCurrentUser(ctx router.Ctx) *users.User {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentSession returns the current session, or nil if not authenticated
|
||||
func GetCurrentSession(ctx router.Ctx) *auth.Session {
|
||||
if session, ok := ctx.UserValue("session").(*auth.Session); ok {
|
||||
return session
|
||||
func GetCurrentSession(ctx router.Ctx) *session.Session {
|
||||
if sess, ok := ctx.UserValue("session").(*session.Session); ok {
|
||||
return sess
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Login creates a session and sets the cookie
|
||||
func Login(ctx router.Ctx, authManager *auth.AuthManager, user *users.User) {
|
||||
session := authManager.CreateSession(user)
|
||||
auth.SetSessionCookie(ctx, session.ID)
|
||||
sess := authManager.CreateSession(user)
|
||||
auth.SetSessionCookie(ctx, sess.ID)
|
||||
|
||||
// Set in context for immediate use
|
||||
ctx.SetUserValue("session", session)
|
||||
ctx.SetUserValue("session", sess)
|
||||
ctx.SetUserValue("user", user)
|
||||
}
|
||||
|
||||
// Logout destroys the session and clears the cookie
|
||||
func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
|
||||
sessionID := auth.GetSessionCookie(ctx)
|
||||
if sessionID != "" {
|
||||
@ -115,7 +103,6 @@ func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
|
||||
|
||||
auth.DeleteSessionCookie(ctx)
|
||||
|
||||
// Clear from context
|
||||
ctx.SetUserValue("session", nil)
|
||||
ctx.SetUserValue("user", nil)
|
||||
}
|
56
internal/session/flash.go
Normal file
56
internal/session/flash.go
Normal file
@ -0,0 +1,56 @@
|
||||
package session
|
||||
|
||||
type FlashMessage struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (s *Session) SetFlash(key string, value any) {
|
||||
if s.Data == nil {
|
||||
s.Data = make(map[string]any)
|
||||
}
|
||||
|
||||
flashData, ok := s.Data["_flash"].(map[string]any)
|
||||
if !ok {
|
||||
flashData = make(map[string]any)
|
||||
}
|
||||
flashData[key] = value
|
||||
s.Data["_flash"] = flashData
|
||||
}
|
||||
|
||||
func (s *Session) GetFlash(key string) (any, bool) {
|
||||
if s.Data == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
flashData, ok := s.Data["_flash"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
value, exists := flashData[key]
|
||||
if exists {
|
||||
delete(flashData, key)
|
||||
if len(flashData) == 0 {
|
||||
delete(s.Data, "_flash")
|
||||
} else {
|
||||
s.Data["_flash"] = flashData
|
||||
}
|
||||
}
|
||||
|
||||
return value, exists
|
||||
}
|
||||
|
||||
func (s *Session) GetAllFlash() map[string]any {
|
||||
if s.Data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
flashData, ok := s.Data["_flash"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(s.Data, "_flash")
|
||||
return flashData
|
||||
}
|
74
internal/session/session.go
Normal file
74
internal/session/session.go
Normal file
@ -0,0 +1,74 @@
|
||||
// Package session provides session management functionality.
|
||||
// It includes session storage, flash messages, and data persistence.
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultExpiration = 24 * time.Hour
|
||||
IDLength = 32
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"-"`
|
||||
UserID int `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
LastSeen time.Time `json:"last_seen"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func New(userID int, username, email string) *Session {
|
||||
return &Session{
|
||||
ID: generateID(),
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(DefaultExpiration),
|
||||
LastSeen: time.Now(),
|
||||
Data: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) IsExpired() bool {
|
||||
return time.Now().After(s.ExpiresAt)
|
||||
}
|
||||
|
||||
func (s *Session) Touch() {
|
||||
s.LastSeen = time.Now()
|
||||
s.ExpiresAt = time.Now().Add(DefaultExpiration)
|
||||
}
|
||||
|
||||
func (s *Session) Set(key string, value any) {
|
||||
if s.Data == nil {
|
||||
s.Data = make(map[string]any)
|
||||
}
|
||||
s.Data[key] = value
|
||||
}
|
||||
|
||||
func (s *Session) Get(key string) (any, bool) {
|
||||
if s.Data == nil {
|
||||
return nil, false
|
||||
}
|
||||
value, exists := s.Data[key]
|
||||
return value, exists
|
||||
}
|
||||
|
||||
func (s *Session) Delete(key string) {
|
||||
if s.Data != nil {
|
||||
delete(s.Data, key)
|
||||
}
|
||||
}
|
||||
|
||||
func generateID() string {
|
||||
bytes := make([]byte, IDLength)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
161
internal/session/store.go
Normal file
161
internal/session/store.go
Normal file
@ -0,0 +1,161 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*Session
|
||||
filePath string
|
||||
saveInterval time.Duration
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
type persistedData struct {
|
||||
Sessions map[string]*Session `json:"sessions"`
|
||||
SavedAt time.Time `json:"saved_at"`
|
||||
}
|
||||
|
||||
func NewStore(filePath string) *Store {
|
||||
store := &Store{
|
||||
sessions: make(map[string]*Session),
|
||||
filePath: filePath,
|
||||
saveInterval: 5 * time.Minute,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
store.loadFromFile()
|
||||
store.startPeriodicSave()
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *Store) Save(session *Session) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[session.ID] = session
|
||||
}
|
||||
|
||||
func (s *Store) Get(sessionID string) (*Session, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if session.IsExpired() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return session, true
|
||||
}
|
||||
|
||||
func (s *Store) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
func (s *Store) Cleanup() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for id, session := range s.sessions {
|
||||
if session.IsExpired() {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) Stats() (total, active int) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
total = len(s.sessions)
|
||||
for _, session := range s.sessions {
|
||||
if !session.IsExpired() {
|
||||
active++
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Store) loadFromFile() {
|
||||
if s.filePath == "" {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(s.filePath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var persisted persistedData
|
||||
if err := json.Unmarshal(data, &persisted); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for id, session := range persisted.Sessions {
|
||||
if !session.IsExpired() {
|
||||
session.ID = id
|
||||
s.sessions[id] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) saveToFile() error {
|
||||
if s.filePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
sessionsCopy := make(map[string]*Session, len(s.sessions))
|
||||
maps.Copy(sessionsCopy, s.sessions)
|
||||
s.mu.RUnlock()
|
||||
|
||||
data := persistedData{
|
||||
Sessions: sessionsCopy,
|
||||
SavedAt: time.Now(),
|
||||
}
|
||||
|
||||
jsonData, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(s.filePath, jsonData, 0600)
|
||||
}
|
||||
|
||||
func (s *Store) startPeriodicSave() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.saveInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.Cleanup()
|
||||
s.saveToFile()
|
||||
case <-s.stopChan:
|
||||
s.saveToFile()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
close(s.stopChan)
|
||||
return s.saveToFile()
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user