diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..9281202 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,93 @@ +package auth + +import ( + "dk/internal/database" + "dk/internal/password" + "dk/internal/users" +) + +type User struct { + ID int + Username string + Email string +} + +type AuthManager struct { + sessionStore *SessionStore + db *database.DB +} + +func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager { + return &AuthManager{ + sessionStore: NewSessionStore(sessionsFilePath), + db: db, + } +} + +func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) { + var user *users.User + var err error + + // Try to find user by username first + user, err = users.GetByUsername(am.db, usernameOrEmail) + if err != nil { + // Try by email if username lookup failed + user, err = users.GetByEmail(am.db, usernameOrEmail) + if err != nil { + return nil, err + } + } + + // Verify password + isValid, err := password.Verify(user.Password, plainPassword) + if err != nil { + return nil, err + } + if !isValid { + return nil, ErrInvalidCredentials + } + + return &User{ + ID: user.ID, + Username: user.Username, + Email: user.Email, + }, nil +} + +func (am *AuthManager) CreateSession(user *User) *Session { + return am.sessionStore.Create(user.ID, user.Username, user.Email) +} + +func (am *AuthManager) GetSession(sessionID string) (*Session, bool) { + return am.sessionStore.Get(sessionID) +} + +func (am *AuthManager) UpdateSession(sessionID string) bool { + return am.sessionStore.Update(sessionID) +} + +func (am *AuthManager) DeleteSession(sessionID string) { + am.sessionStore.Delete(sessionID) +} + +func (am *AuthManager) SessionStats() (total, active int) { + return am.sessionStore.Stats() +} + +func (am *AuthManager) Close() error { + return am.sessionStore.Close() +} + +var ( + ErrInvalidCredentials = &AuthError{"invalid username/email or password"} + ErrSessionNotFound = &AuthError{"session not found"} + ErrSessionExpired = &AuthError{"session expired"} +) + +type AuthError struct { + Message string +} + +func (e *AuthError) Error() string { + return e.Message +} \ No newline at end of file diff --git a/internal/auth/cookies.go b/internal/auth/cookies.go new file mode 100644 index 0000000..437b624 --- /dev/null +++ b/internal/auth/cookies.go @@ -0,0 +1,103 @@ +package auth + +import ( + "time" + + "github.com/valyala/fasthttp" +) + +type CookieOptions struct { + Name string + Value string + Path string + Domain string + Expires time.Time + MaxAge int + Secure bool + HTTPOnly bool + SameSite string +} + +func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) { + cookie := &fasthttp.Cookie{} + + cookie.SetKey(opts.Name) + cookie.SetValue(opts.Value) + + if opts.Path != "" { + cookie.SetPath(opts.Path) + } else { + cookie.SetPath("/") + } + + if opts.Domain != "" { + cookie.SetDomain(opts.Domain) + } + + if !opts.Expires.IsZero() { + cookie.SetExpire(opts.Expires) + } + + if opts.MaxAge > 0 { + cookie.SetMaxAge(opts.MaxAge) + } + + cookie.SetSecure(opts.Secure) + cookie.SetHTTPOnly(opts.HTTPOnly) + + switch opts.SameSite { + case "strict": + cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) + case "lax": + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + case "none": + cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) + default: + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + } + + ctx.Response.Header.SetCookie(cookie) +} + +func GetCookie(ctx *fasthttp.RequestCtx, name string) string { + return string(ctx.Request.Header.Cookie(name)) +} + +func DeleteCookie(ctx *fasthttp.RequestCtx, name string) { + SetSecureCookie(ctx, CookieOptions{ + Name: name, + Value: "", + Path: "/", + Expires: time.Unix(0, 0), + MaxAge: -1, + HTTPOnly: true, + Secure: true, + SameSite: "lax", + }) +} + +func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) { + SetSecureCookie(ctx, CookieOptions{ + Name: SessionCookieName, + Value: sessionID, + Path: "/", + Expires: time.Now().Add(DefaultExpiration), + HTTPOnly: true, + Secure: isHTTPS(ctx), + SameSite: "lax", + }) +} + +func GetSessionCookie(ctx *fasthttp.RequestCtx) string { + return GetCookie(ctx, SessionCookieName) +} + +func DeleteSessionCookie(ctx *fasthttp.RequestCtx) { + DeleteCookie(ctx, SessionCookieName) +} + +func isHTTPS(ctx *fasthttp.RequestCtx) bool { + return ctx.IsTLS() || + string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" || + string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https" +} \ No newline at end of file diff --git a/internal/auth/doc.go b/internal/auth/doc.go new file mode 100644 index 0000000..48cd94f --- /dev/null +++ b/internal/auth/doc.go @@ -0,0 +1,4 @@ +// Package auth provides authentication and session management functionality. +// It includes secure session storage with in-memory caching and JSON persistence, +// user authentication against the database, and secure cookie handling. +package auth \ No newline at end of file diff --git a/internal/auth/session.go b/internal/auth/session.go new file mode 100644 index 0000000..38203eb --- /dev/null +++ b/internal/auth/session.go @@ -0,0 +1,221 @@ +package auth + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "maps" + "os" + "sync" + "time" +) + +const ( + SessionCookieName = "dk_session" + DefaultExpiration = 24 * time.Hour + SessionIDLength = 32 +) + +type Session struct { + ID string `json:"id"` + UserID int `json:"user_id"` + Username string `json:"username"` + Email string `json:"email"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + LastSeen time.Time `json:"last_seen"` +} + +type SessionStore struct { + mu sync.RWMutex + sessions map[string]*Session + filePath string + saveInterval time.Duration + stopChan chan struct{} +} + +type persistedData struct { + Sessions map[string]*Session `json:"sessions"` + SavedAt time.Time `json:"saved_at"` +} + +func NewSessionStore(filePath string) *SessionStore { + store := &SessionStore{ + sessions: make(map[string]*Session), + filePath: filePath, + saveInterval: 5 * time.Minute, + stopChan: make(chan struct{}), + } + + store.loadFromFile() + store.startPeriodicSave() + + return store +} + +func (s *SessionStore) generateSessionID() string { + bytes := make([]byte, SessionIDLength) + rand.Read(bytes) + return hex.EncodeToString(bytes) +} + +func (s *SessionStore) Create(userID int, username, email string) *Session { + s.mu.Lock() + defer s.mu.Unlock() + + session := &Session{ + ID: s.generateSessionID(), + UserID: userID, + Username: username, + Email: email, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(DefaultExpiration), + LastSeen: time.Now(), + } + + s.sessions[session.ID] = session + return session +} + +func (s *SessionStore) Get(sessionID string) (*Session, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + session, exists := s.sessions[sessionID] + if !exists { + return nil, false + } + + if time.Now().After(session.ExpiresAt) { + delete(s.sessions, sessionID) + return nil, false + } + + return session, true +} + +func (s *SessionStore) Update(sessionID string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + session, exists := s.sessions[sessionID] + if !exists { + return false + } + + if time.Now().After(session.ExpiresAt) { + delete(s.sessions, sessionID) + return false + } + + session.LastSeen = time.Now() + session.ExpiresAt = time.Now().Add(DefaultExpiration) + return true +} + +func (s *SessionStore) Delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.sessions, sessionID) +} + +func (s *SessionStore) Cleanup() { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + for id, session := range s.sessions { + if now.After(session.ExpiresAt) { + delete(s.sessions, id) + } + } +} + +func (s *SessionStore) loadFromFile() { + if s.filePath == "" { + return + } + + data, err := os.ReadFile(s.filePath) + if err != nil { + return // File might not exist yet + } + + var persisted persistedData + if err := json.Unmarshal(data, &persisted); err != nil { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + for id, session := range persisted.Sessions { + if now.Before(session.ExpiresAt) { + s.sessions[id] = session + } + } +} + +func (s *SessionStore) saveToFile() error { + if s.filePath == "" { + return nil + } + + s.mu.RLock() + sessionsCopy := make(map[string]*Session) + maps.Copy(sessionsCopy, s.sessions) + s.mu.RUnlock() + + data := persistedData{ + Sessions: sessionsCopy, + SavedAt: time.Now(), + } + + jsonData, err := json.MarshalIndent(data, "", " ") + if err != nil { + return err + } + + return os.WriteFile(s.filePath, jsonData, 0600) +} + +func (s *SessionStore) startPeriodicSave() { + go func() { + ticker := time.NewTicker(s.saveInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.Cleanup() + s.saveToFile() + case <-s.stopChan: + s.saveToFile() + return + } + } + }() +} + +func (s *SessionStore) Close() error { + close(s.stopChan) + return s.saveToFile() +} + +func (s *SessionStore) Stats() (total, active int) { + s.mu.RLock() + defer s.mu.RUnlock() + + now := time.Now() + total = len(s.sessions) + + for _, session := range s.sessions { + if now.Before(session.ExpiresAt) { + active++ + } + } + + return +} diff --git a/internal/database/database.go b/internal/database/database.go index 7332788..66bc318 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -57,6 +57,16 @@ func (db *DB) Close() error { return db.pool.Close() } +// GetConn gets a connection from the pool - caller must call Put when done +func (db *DB) GetConn(ctx context.Context) (*sqlite.Conn, error) { + return db.pool.Take(ctx) +} + +// PutConn returns a connection to the pool +func (db *DB) PutConn(conn *sqlite.Conn) { + db.pool.Put(conn) +} + // Exec executes a SQL statement without returning results func (db *DB) Exec(query string, args ...any) error { conn, err := db.pool.Take(context.Background()) diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go new file mode 100644 index 0000000..1ab0321 --- /dev/null +++ b/internal/middleware/auth.go @@ -0,0 +1,116 @@ +package middleware + +import ( + "dk/internal/auth" + "dk/internal/router" + + "github.com/valyala/fasthttp" +) + +const ( + UserKey = "user" + SessionKey = "session" +) + +// Auth creates an authentication middleware +func Auth(authManager *auth.AuthManager) router.Middleware { + return func(next router.Handler) router.Handler { + return func(ctx router.Ctx, params []string) { + sessionID := auth.GetSessionCookie(ctx) + + if sessionID != "" { + if session, exists := authManager.GetSession(sessionID); exists { + // Update session activity + authManager.UpdateSession(sessionID) + + // Store session and user info in context + ctx.SetUserValue(SessionKey, session) + ctx.SetUserValue(UserKey, &auth.User{ + ID: session.UserID, + Username: session.Username, + Email: session.Email, + }) + + // Refresh the cookie + auth.SetSessionCookie(ctx, sessionID) + } + } + + next(ctx, params) + } + } +} + +// RequireAuth enforces authentication - redirects to login if not authenticated +func RequireAuth(loginPath string) router.Middleware { + return func(next router.Handler) router.Handler { + return func(ctx router.Ctx, params []string) { + if !IsAuthenticated(ctx) { + ctx.Redirect(loginPath, fasthttp.StatusFound) + return + } + + next(ctx, params) + } + } +} + +// RequireGuest enforces no authentication - redirects to dashboard if authenticated +func RequireGuest(dashboardPath string) router.Middleware { + return func(next router.Handler) router.Handler { + return func(ctx router.Ctx, params []string) { + if IsAuthenticated(ctx) { + ctx.Redirect(dashboardPath, fasthttp.StatusFound) + return + } + + next(ctx, params) + } + } +} + +// IsAuthenticated checks if the current request has a valid session +func IsAuthenticated(ctx router.Ctx) bool { + _, exists := ctx.UserValue(UserKey).(*auth.User) + return exists +} + +// GetCurrentUser returns the current authenticated user, or nil if not authenticated +func GetCurrentUser(ctx router.Ctx) *auth.User { + if user, ok := ctx.UserValue(UserKey).(*auth.User); ok { + return user + } + return nil +} + +// GetCurrentSession returns the current session, or nil if not authenticated +func GetCurrentSession(ctx router.Ctx) *auth.Session { + if session, ok := ctx.UserValue(SessionKey).(*auth.Session); ok { + return session + } + return nil +} + +// Login creates a session and sets the cookie +func Login(ctx router.Ctx, authManager *auth.AuthManager, user *auth.User) { + session := authManager.CreateSession(user) + auth.SetSessionCookie(ctx, session.ID) + + // Set in context for immediate use + ctx.SetUserValue(SessionKey, session) + ctx.SetUserValue(UserKey, user) +} + +// Logout destroys the session and clears the cookie +func Logout(ctx router.Ctx, authManager *auth.AuthManager) { + sessionID := auth.GetSessionCookie(ctx) + if sessionID != "" { + authManager.DeleteSession(sessionID) + } + + auth.DeleteSessionCookie(ctx) + + // Clear from context + ctx.SetUserValue(SessionKey, nil) + ctx.SetUserValue(UserKey, nil) +} \ No newline at end of file diff --git a/internal/server/server.go b/internal/server/server.go index 2f28250..1679ecf 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,8 +4,12 @@ import ( "fmt" "log" "os" + "os/signal" "path/filepath" + "syscall" + "dk/internal/auth" + "dk/internal/database" "dk/internal/middleware" "dk/internal/router" "dk/internal/template" @@ -21,11 +25,23 @@ func Start(port string) error { } templateCache := template.NewCache(cwd) + // Initialize database + db, err := database.Open("dk.sqlite") + if err != nil { + return fmt.Errorf("failed to open database: %w", err) + } + defer db.Close() + + // Initialize authentication manager + authManager := auth.NewAuthManager(db, "sessions.json") + defer authManager.Close() + // Initialize router r := router.New() - // Add timing middleware + // Add middleware r.Use(middleware.Timing()) + r.Use(middleware.Auth(authManager)) // Hello world endpoint r.Get("/", func(ctx router.Ctx, params []string) { @@ -36,13 +52,28 @@ func Start(port string) error { return } + // Get current user if authenticated + currentUser := middleware.GetCurrentUser(ctx) + var username string + if currentUser != nil { + username = currentUser.Username + } else { + username = "Guest" + } + + totalSessions, activeSessions := authManager.SessionStats() + data := map[string]any{ - "title": "Dragon Knight", - "content": "Hello World!", - "totaltime": middleware.GetRequestTime(ctx), - "numqueries": "0", // Placeholder for now - "version": "1.0.0", - "build": "dev", + "title": "Dragon Knight", + "content": fmt.Sprintf("Hello %s!", username), + "totaltime": middleware.GetRequestTime(ctx), + "numqueries": "0", // Placeholder for now + "version": "1.0.0", + "build": "dev", + "total_sessions": totalSessions, + "active_sessions": activeSessions, + "authenticated": currentUser != nil, + "username": username, } tmpl.WriteTo(ctx, data) @@ -78,5 +109,32 @@ func Start(port string) error { addr := ":" + port log.Printf("Server starting on %s", addr) - return fasthttp.ListenAndServe(addr, requestHandler) + + // Setup graceful shutdown + server := &fasthttp.Server{ + Handler: requestHandler, + } + + // Channel to listen for interrupt signal + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + // Start server in a goroutine + go func() { + if err := server.ListenAndServe(addr); err != nil { + log.Printf("Server error: %v", err) + } + }() + + // Block until we receive a signal + <-c + log.Println("Shutting down server...") + + // Shutdown server gracefully + if err := server.Shutdown(); err != nil { + log.Printf("Server shutdown error: %v", err) + } + + log.Println("Server stopped") + return nil } diff --git a/internal/users/users.go b/internal/users/users.go index 8c14579..859e2bb 100644 --- a/internal/users/users.go +++ b/internal/users/users.go @@ -422,3 +422,47 @@ func (u *User) SetPosition(x, y int) { u.X = x u.Y = y } + +// GetByUsername retrieves a user by username +func GetByUsername(db *database.DB, username string) (*User, error) { + var user *User + + query := `SELECT ` + userColumns() + ` FROM users WHERE LOWER(username) = LOWER(?) LIMIT 1` + + err := db.Query(query, func(stmt *sqlite.Stmt) error { + user = scanUser(stmt, db) + return nil + }, username) + + if err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + + if user == nil { + return nil, fmt.Errorf("user not found: %s", username) + } + + return user, nil +} + +// GetByEmail retrieves a user by email +func GetByEmail(db *database.DB, email string) (*User, error) { + var user *User + + query := `SELECT ` + userColumns() + ` FROM users WHERE LOWER(email) = LOWER(?) LIMIT 1` + + err := db.Query(query, func(stmt *sqlite.Stmt) error { + user = scanUser(stmt, db) + return nil + }, email) + + if err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + + if user == nil { + return nil, fmt.Errorf("user not found: %s", email) + } + + return user, nil +} diff --git a/templates/auth/login.html b/templates/auth/login.html new file mode 100644 index 0000000..a5ec071 --- /dev/null +++ b/templates/auth/login.html @@ -0,0 +1,37 @@ +{flashhtml} + +
+ {csrf} + + + + + + + + + + + + + + + + + + + +
Username:
Password:
Remember me? Yes
+ Checking the "Remember Me" option will store your login information in a cookie + so you don't have to enter it next time you get online. + +

+ + Want to play? You gotta register your own character. + +

+ + You may also change your password, or + request a new one if you've lost yours. +
+
diff --git a/templates/auth/register.html b/templates/auth/register.html new file mode 100644 index 0000000..8396d81 --- /dev/null +++ b/templates/auth/register.html @@ -0,0 +1,60 @@ +{flashhtml} + +
+ {csrf} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Username: + +
+ Usernames must be 30 alphanumeric characters or less. +


+
Password:
Verify Password: + +
+ Passwords must be 10 alphanumeric characters or less. +


+
Email Address:
Verify Email: + + {verifytext} +


+
Character Class: + +
See Help for more information about character classes.

+ + +
+