auth system first pass

This commit is contained in:
Sky Johnson 2025-08-09 09:44:44 -05:00
parent a1e8d49c0e
commit a49346160b
10 changed files with 754 additions and 8 deletions

93
internal/auth/auth.go Normal file
View File

@ -0,0 +1,93 @@
package auth
import (
"dk/internal/database"
"dk/internal/password"
"dk/internal/users"
)
type User struct {
ID int
Username string
Email string
}
type AuthManager struct {
sessionStore *SessionStore
db *database.DB
}
func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager {
return &AuthManager{
sessionStore: NewSessionStore(sessionsFilePath),
db: db,
}
}
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
var user *users.User
var err error
// Try to find user by username first
user, err = users.GetByUsername(am.db, usernameOrEmail)
if err != nil {
// Try by email if username lookup failed
user, err = users.GetByEmail(am.db, usernameOrEmail)
if err != nil {
return nil, err
}
}
// Verify password
isValid, err := password.Verify(user.Password, plainPassword)
if err != nil {
return nil, err
}
if !isValid {
return nil, ErrInvalidCredentials
}
return &User{
ID: user.ID,
Username: user.Username,
Email: user.Email,
}, nil
}
func (am *AuthManager) CreateSession(user *User) *Session {
return am.sessionStore.Create(user.ID, user.Username, user.Email)
}
func (am *AuthManager) GetSession(sessionID string) (*Session, bool) {
return am.sessionStore.Get(sessionID)
}
func (am *AuthManager) UpdateSession(sessionID string) bool {
return am.sessionStore.Update(sessionID)
}
func (am *AuthManager) DeleteSession(sessionID string) {
am.sessionStore.Delete(sessionID)
}
func (am *AuthManager) SessionStats() (total, active int) {
return am.sessionStore.Stats()
}
func (am *AuthManager) Close() error {
return am.sessionStore.Close()
}
var (
ErrInvalidCredentials = &AuthError{"invalid username/email or password"}
ErrSessionNotFound = &AuthError{"session not found"}
ErrSessionExpired = &AuthError{"session expired"}
)
type AuthError struct {
Message string
}
func (e *AuthError) Error() string {
return e.Message
}

103
internal/auth/cookies.go Normal file
View File

@ -0,0 +1,103 @@
package auth
import (
"time"
"github.com/valyala/fasthttp"
)
type CookieOptions struct {
Name string
Value string
Path string
Domain string
Expires time.Time
MaxAge int
Secure bool
HTTPOnly bool
SameSite string
}
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(opts.Name)
cookie.SetValue(opts.Value)
if opts.Path != "" {
cookie.SetPath(opts.Path)
} else {
cookie.SetPath("/")
}
if opts.Domain != "" {
cookie.SetDomain(opts.Domain)
}
if !opts.Expires.IsZero() {
cookie.SetExpire(opts.Expires)
}
if opts.MaxAge > 0 {
cookie.SetMaxAge(opts.MaxAge)
}
cookie.SetSecure(opts.Secure)
cookie.SetHTTPOnly(opts.HTTPOnly)
switch opts.SameSite {
case "strict":
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "lax":
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
case "none":
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
ctx.Response.Header.SetCookie(cookie)
}
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
return string(ctx.Request.Header.Cookie(name))
}
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
SetSecureCookie(ctx, CookieOptions{
Name: name,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HTTPOnly: true,
Secure: true,
SameSite: "lax",
})
}
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
SetSecureCookie(ctx, CookieOptions{
Name: SessionCookieName,
Value: sessionID,
Path: "/",
Expires: time.Now().Add(DefaultExpiration),
HTTPOnly: true,
Secure: isHTTPS(ctx),
SameSite: "lax",
})
}
func GetSessionCookie(ctx *fasthttp.RequestCtx) string {
return GetCookie(ctx, SessionCookieName)
}
func DeleteSessionCookie(ctx *fasthttp.RequestCtx) {
DeleteCookie(ctx, SessionCookieName)
}
func isHTTPS(ctx *fasthttp.RequestCtx) bool {
return ctx.IsTLS() ||
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
}

4
internal/auth/doc.go Normal file
View File

@ -0,0 +1,4 @@
// Package auth provides authentication and session management functionality.
// It includes secure session storage with in-memory caching and JSON persistence,
// user authentication against the database, and secure cookie handling.
package auth

221
internal/auth/session.go Normal file
View File

@ -0,0 +1,221 @@
package auth
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"maps"
"os"
"sync"
"time"
)
const (
SessionCookieName = "dk_session"
DefaultExpiration = 24 * time.Hour
SessionIDLength = 32
)
type Session struct {
ID string `json:"id"`
UserID int `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
LastSeen time.Time `json:"last_seen"`
}
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*Session
filePath string
saveInterval time.Duration
stopChan chan struct{}
}
type persistedData struct {
Sessions map[string]*Session `json:"sessions"`
SavedAt time.Time `json:"saved_at"`
}
func NewSessionStore(filePath string) *SessionStore {
store := &SessionStore{
sessions: make(map[string]*Session),
filePath: filePath,
saveInterval: 5 * time.Minute,
stopChan: make(chan struct{}),
}
store.loadFromFile()
store.startPeriodicSave()
return store
}
func (s *SessionStore) generateSessionID() string {
bytes := make([]byte, SessionIDLength)
rand.Read(bytes)
return hex.EncodeToString(bytes)
}
func (s *SessionStore) Create(userID int, username, email string) *Session {
s.mu.Lock()
defer s.mu.Unlock()
session := &Session{
ID: s.generateSessionID(),
UserID: userID,
Username: username,
Email: email,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(DefaultExpiration),
LastSeen: time.Now(),
}
s.sessions[session.ID] = session
return session
}
func (s *SessionStore) Get(sessionID string) (*Session, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, exists := s.sessions[sessionID]
if !exists {
return nil, false
}
if time.Now().After(session.ExpiresAt) {
delete(s.sessions, sessionID)
return nil, false
}
return session, true
}
func (s *SessionStore) Update(sessionID string) bool {
s.mu.Lock()
defer s.mu.Unlock()
session, exists := s.sessions[sessionID]
if !exists {
return false
}
if time.Now().After(session.ExpiresAt) {
delete(s.sessions, sessionID)
return false
}
session.LastSeen = time.Now()
session.ExpiresAt = time.Now().Add(DefaultExpiration)
return true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
for id, session := range s.sessions {
if now.After(session.ExpiresAt) {
delete(s.sessions, id)
}
}
}
func (s *SessionStore) loadFromFile() {
if s.filePath == "" {
return
}
data, err := os.ReadFile(s.filePath)
if err != nil {
return // File might not exist yet
}
var persisted persistedData
if err := json.Unmarshal(data, &persisted); err != nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
for id, session := range persisted.Sessions {
if now.Before(session.ExpiresAt) {
s.sessions[id] = session
}
}
}
func (s *SessionStore) saveToFile() error {
if s.filePath == "" {
return nil
}
s.mu.RLock()
sessionsCopy := make(map[string]*Session)
maps.Copy(sessionsCopy, s.sessions)
s.mu.RUnlock()
data := persistedData{
Sessions: sessionsCopy,
SavedAt: time.Now(),
}
jsonData, err := json.MarshalIndent(data, "", " ")
if err != nil {
return err
}
return os.WriteFile(s.filePath, jsonData, 0600)
}
func (s *SessionStore) startPeriodicSave() {
go func() {
ticker := time.NewTicker(s.saveInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.Cleanup()
s.saveToFile()
case <-s.stopChan:
s.saveToFile()
return
}
}
}()
}
func (s *SessionStore) Close() error {
close(s.stopChan)
return s.saveToFile()
}
func (s *SessionStore) Stats() (total, active int) {
s.mu.RLock()
defer s.mu.RUnlock()
now := time.Now()
total = len(s.sessions)
for _, session := range s.sessions {
if now.Before(session.ExpiresAt) {
active++
}
}
return
}

View File

@ -57,6 +57,16 @@ func (db *DB) Close() error {
return db.pool.Close()
}
// GetConn gets a connection from the pool - caller must call Put when done
func (db *DB) GetConn(ctx context.Context) (*sqlite.Conn, error) {
return db.pool.Take(ctx)
}
// PutConn returns a connection to the pool
func (db *DB) PutConn(conn *sqlite.Conn) {
db.pool.Put(conn)
}
// Exec executes a SQL statement without returning results
func (db *DB) Exec(query string, args ...any) error {
conn, err := db.pool.Take(context.Background())

116
internal/middleware/auth.go Normal file
View File

@ -0,0 +1,116 @@
package middleware
import (
"dk/internal/auth"
"dk/internal/router"
"github.com/valyala/fasthttp"
)
const (
UserKey = "user"
SessionKey = "session"
)
// Auth creates an authentication middleware
func Auth(authManager *auth.AuthManager) router.Middleware {
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
sessionID := auth.GetSessionCookie(ctx)
if sessionID != "" {
if session, exists := authManager.GetSession(sessionID); exists {
// Update session activity
authManager.UpdateSession(sessionID)
// Store session and user info in context
ctx.SetUserValue(SessionKey, session)
ctx.SetUserValue(UserKey, &auth.User{
ID: session.UserID,
Username: session.Username,
Email: session.Email,
})
// Refresh the cookie
auth.SetSessionCookie(ctx, sessionID)
}
}
next(ctx, params)
}
}
}
// RequireAuth enforces authentication - redirects to login if not authenticated
func RequireAuth(loginPath string) router.Middleware {
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
if !IsAuthenticated(ctx) {
ctx.Redirect(loginPath, fasthttp.StatusFound)
return
}
next(ctx, params)
}
}
}
// RequireGuest enforces no authentication - redirects to dashboard if authenticated
func RequireGuest(dashboardPath string) router.Middleware {
return func(next router.Handler) router.Handler {
return func(ctx router.Ctx, params []string) {
if IsAuthenticated(ctx) {
ctx.Redirect(dashboardPath, fasthttp.StatusFound)
return
}
next(ctx, params)
}
}
}
// IsAuthenticated checks if the current request has a valid session
func IsAuthenticated(ctx router.Ctx) bool {
_, exists := ctx.UserValue(UserKey).(*auth.User)
return exists
}
// GetCurrentUser returns the current authenticated user, or nil if not authenticated
func GetCurrentUser(ctx router.Ctx) *auth.User {
if user, ok := ctx.UserValue(UserKey).(*auth.User); ok {
return user
}
return nil
}
// GetCurrentSession returns the current session, or nil if not authenticated
func GetCurrentSession(ctx router.Ctx) *auth.Session {
if session, ok := ctx.UserValue(SessionKey).(*auth.Session); ok {
return session
}
return nil
}
// Login creates a session and sets the cookie
func Login(ctx router.Ctx, authManager *auth.AuthManager, user *auth.User) {
session := authManager.CreateSession(user)
auth.SetSessionCookie(ctx, session.ID)
// Set in context for immediate use
ctx.SetUserValue(SessionKey, session)
ctx.SetUserValue(UserKey, user)
}
// Logout destroys the session and clears the cookie
func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
sessionID := auth.GetSessionCookie(ctx)
if sessionID != "" {
authManager.DeleteSession(sessionID)
}
auth.DeleteSessionCookie(ctx)
// Clear from context
ctx.SetUserValue(SessionKey, nil)
ctx.SetUserValue(UserKey, nil)
}

View File

@ -4,8 +4,12 @@ import (
"fmt"
"log"
"os"
"os/signal"
"path/filepath"
"syscall"
"dk/internal/auth"
"dk/internal/database"
"dk/internal/middleware"
"dk/internal/router"
"dk/internal/template"
@ -21,11 +25,23 @@ func Start(port string) error {
}
templateCache := template.NewCache(cwd)
// Initialize database
db, err := database.Open("dk.sqlite")
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
}
defer db.Close()
// Initialize authentication manager
authManager := auth.NewAuthManager(db, "sessions.json")
defer authManager.Close()
// Initialize router
r := router.New()
// Add timing middleware
// Add middleware
r.Use(middleware.Timing())
r.Use(middleware.Auth(authManager))
// Hello world endpoint
r.Get("/", func(ctx router.Ctx, params []string) {
@ -36,13 +52,28 @@ func Start(port string) error {
return
}
// Get current user if authenticated
currentUser := middleware.GetCurrentUser(ctx)
var username string
if currentUser != nil {
username = currentUser.Username
} else {
username = "Guest"
}
totalSessions, activeSessions := authManager.SessionStats()
data := map[string]any{
"title": "Dragon Knight",
"content": "Hello World!",
"totaltime": middleware.GetRequestTime(ctx),
"numqueries": "0", // Placeholder for now
"version": "1.0.0",
"build": "dev",
"title": "Dragon Knight",
"content": fmt.Sprintf("Hello %s!", username),
"totaltime": middleware.GetRequestTime(ctx),
"numqueries": "0", // Placeholder for now
"version": "1.0.0",
"build": "dev",
"total_sessions": totalSessions,
"active_sessions": activeSessions,
"authenticated": currentUser != nil,
"username": username,
}
tmpl.WriteTo(ctx, data)
@ -78,5 +109,32 @@ func Start(port string) error {
addr := ":" + port
log.Printf("Server starting on %s", addr)
return fasthttp.ListenAndServe(addr, requestHandler)
// Setup graceful shutdown
server := &fasthttp.Server{
Handler: requestHandler,
}
// Channel to listen for interrupt signal
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
// Start server in a goroutine
go func() {
if err := server.ListenAndServe(addr); err != nil {
log.Printf("Server error: %v", err)
}
}()
// Block until we receive a signal
<-c
log.Println("Shutting down server...")
// Shutdown server gracefully
if err := server.Shutdown(); err != nil {
log.Printf("Server shutdown error: %v", err)
}
log.Println("Server stopped")
return nil
}

View File

@ -422,3 +422,47 @@ func (u *User) SetPosition(x, y int) {
u.X = x
u.Y = y
}
// GetByUsername retrieves a user by username
func GetByUsername(db *database.DB, username string) (*User, error) {
var user *User
query := `SELECT ` + userColumns() + ` FROM users WHERE LOWER(username) = LOWER(?) LIMIT 1`
err := db.Query(query, func(stmt *sqlite.Stmt) error {
user = scanUser(stmt, db)
return nil
}, username)
if err != nil {
return nil, fmt.Errorf("query failed: %w", err)
}
if user == nil {
return nil, fmt.Errorf("user not found: %s", username)
}
return user, nil
}
// GetByEmail retrieves a user by email
func GetByEmail(db *database.DB, email string) (*User, error) {
var user *User
query := `SELECT ` + userColumns() + ` FROM users WHERE LOWER(email) = LOWER(?) LIMIT 1`
err := db.Query(query, func(stmt *sqlite.Stmt) error {
user = scanUser(stmt, db)
return nil
}, email)
if err != nil {
return nil, fmt.Errorf("query failed: %w", err)
}
if user == nil {
return nil, fmt.Errorf("user not found: %s", email)
}
return user, nil
}

37
templates/auth/login.html Normal file
View File

@ -0,0 +1,37 @@
{flashhtml}
<form action="/login" method="post">
{csrf}
<table width="75%">
<tr>
<td width="30%">Username:</td>
<td><input type="text" size="30" name="username"></td>
</tr>
<tr>
<td>Password:</td>
<td><input type="password" size="30" name="password"></td>
</tr>
<tr>
<td>Remember me?</td>
<td><input type="checkbox" name="rememberme" value="yes"> Yes</td>
</tr>
<tr>
<td colspan="2"><input type="submit" name="submit" value="Log In"></td>
</tr>
<tr>
<td colspan="2">
Checking the "Remember Me" option will store your login information in a cookie
so you don't have to enter it next time you get online.
<br><br>
Want to play? You gotta <a href="/register">register your own character.</a>
<br><br>
You may also <a href="/change-password">change your password</a>, or
<a href="/lost-password">request a new one</a> if you've lost yours.
</td>
</tr>
</table>
</form>

View File

@ -0,0 +1,60 @@
{flashhtml}
<form action="/register" method="post">
{csrf}
<table width="80%">
<tr>
<td width="20%">Username:</td>
<td>
<input type="text" name="username" size="30" maxlength="30">
<br>
Usernames must be 30 alphanumeric characters or less.
<br><br><br>
</td>
</tr>
<tr>
<td>Password:</td>
<td><input type="password" name="password1" size="30" maxlength="10"></td>
</tr>
<tr>
<td>Verify Password:</td>
<td>
<input type="password" name="password2" size="30" maxlength="10">
<br>
Passwords must be 10 alphanumeric characters or less.
<br><br><br>
</td>
</tr>
<tr>
<td>Email Address:</td>
<td><input type="email" name="email1" size="30" maxlength="100"></td>
</tr>
<tr>
<td>Verify Email:</td>
<td>
<input type="text" name="email2" size="30" maxlength="100">
{verifytext}
<br><br><br>
</td>
</tr>
<tr>
<td>Character Class:</td>
<td>
<select name="charclass">
<option value="1">{class1name}</option>
<option value="2">{class2name}</option>
<option value="3">{class3name}</option>
</select>
</td>
</tr>
<tr>
<td colspan="2">See <a href="/help">Help</a> for more information about character classes.<br><br></td>
</tr>
<tr>
<td colspan="2">
<input type="submit" name="submit" value="Submit">
<input type="reset" name="reset" value="Reset">
</td>
</tr>
</table>
</form>