Simplify auth package

This commit is contained in:
Sky Johnson 2025-08-09 18:12:23 -05:00
parent 80700149f8
commit b1f436585e
5 changed files with 110 additions and 109 deletions

View File

@ -1,7 +1,6 @@
package auth package auth
import ( import (
"dk/internal/database"
"dk/internal/password" "dk/internal/password"
"dk/internal/users" "dk/internal/users"
) )
@ -9,29 +8,28 @@ import (
// Manager is the global singleton instance // Manager is the global singleton instance
var Manager *AuthManager var Manager *AuthManager
// User is a simplified User struct for auth purposes
type User struct { type User struct {
ID int ID int
Username string Username string
Email string Email string
} }
// AuthManager is a wrapper for the session store to add
// authentication tools over the store itself
type AuthManager struct { type AuthManager struct {
sessionStore *SessionStore store *SessionStore
db *database.DB
} }
func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager { // Init initializes the global auth manager (auth.Manager)
return &AuthManager{ func Init(sessionsFilePath string) {
sessionStore: NewSessionStore(sessionsFilePath), Manager = &AuthManager{
db: db, store: NewSessionStore(sessionsFilePath),
} }
} }
// InitializeManager initializes the global Manager singleton // Authenticate checks for the usernaname or email, then verifies the plain password
func InitializeManager(db *database.DB, sessionsFilePath string) { // against the stored hash.
Manager = NewAuthManager(db, sessionsFilePath)
}
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) { func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
var user *users.User var user *users.User
var err error var err error
@ -39,14 +37,12 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*Use
// Try to find user by username first // Try to find user by username first
user, err = users.GetByUsername(usernameOrEmail) user, err = users.GetByUsername(usernameOrEmail)
if err != nil { if err != nil {
// Try by email if username lookup failed
user, err = users.GetByEmail(usernameOrEmail) user, err = users.GetByEmail(usernameOrEmail)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
// Verify password
isValid, err := password.Verify(plainPassword, user.Password) isValid, err := password.Verify(plainPassword, user.Password)
if err != nil { if err != nil {
return nil, err return nil, err
@ -63,31 +59,27 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*Use
} }
func (am *AuthManager) CreateSession(user *User) *Session { func (am *AuthManager) CreateSession(user *User) *Session {
return am.sessionStore.Create(user.ID, user.Username, user.Email) return am.store.Create(user.ID, user.Username, user.Email)
} }
func (am *AuthManager) GetSession(sessionID string) (*Session, bool) { func (am *AuthManager) GetSession(sessionID string) (*Session, bool) {
return am.sessionStore.Get(sessionID) return am.store.Get(sessionID)
} }
func (am *AuthManager) UpdateSession(sessionID string) bool { func (am *AuthManager) UpdateSession(sessionID string) bool {
return am.sessionStore.Update(sessionID) return am.store.Update(sessionID)
} }
func (am *AuthManager) DeleteSession(sessionID string) { func (am *AuthManager) DeleteSession(sessionID string) {
am.sessionStore.Delete(sessionID) am.store.Delete(sessionID)
} }
func (am *AuthManager) SessionStats() (total, active int) { func (am *AuthManager) SessionStats() (total, active int) {
return am.sessionStore.Stats() return am.store.Stats()
}
func (am *AuthManager) DB() *database.DB {
return am.db
} }
func (am *AuthManager) Close() error { func (am *AuthManager) Close() error {
return am.sessionStore.Close() return am.store.Close()
} }
var ( var (

View File

@ -1,103 +1,29 @@
package auth package auth
import ( import (
"dk/internal/cookies"
"dk/internal/utils"
"time" "time"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
type CookieOptions struct {
Name string
Value string
Path string
Domain string
Expires time.Time
MaxAge int
Secure bool
HTTPOnly bool
SameSite string
}
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(opts.Name)
cookie.SetValue(opts.Value)
if opts.Path != "" {
cookie.SetPath(opts.Path)
} else {
cookie.SetPath("/")
}
if opts.Domain != "" {
cookie.SetDomain(opts.Domain)
}
if !opts.Expires.IsZero() {
cookie.SetExpire(opts.Expires)
}
if opts.MaxAge > 0 {
cookie.SetMaxAge(opts.MaxAge)
}
cookie.SetSecure(opts.Secure)
cookie.SetHTTPOnly(opts.HTTPOnly)
switch opts.SameSite {
case "strict":
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "lax":
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
case "none":
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
ctx.Response.Header.SetCookie(cookie)
}
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
return string(ctx.Request.Header.Cookie(name))
}
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
SetSecureCookie(ctx, CookieOptions{
Name: name,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HTTPOnly: true,
Secure: true,
SameSite: "lax",
})
}
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) { func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
SetSecureCookie(ctx, CookieOptions{ cookies.SetSecureCookie(ctx, cookies.CookieOptions{
Name: SessionCookieName, Name: SessionCookieName,
Value: sessionID, Value: sessionID,
Path: "/", Path: "/",
Expires: time.Now().Add(DefaultExpiration), Expires: time.Now().Add(DefaultExpiration),
HTTPOnly: true, HTTPOnly: true,
Secure: isHTTPS(ctx), Secure: utils.IsHTTPS(ctx),
SameSite: "lax", SameSite: "lax",
}) })
} }
func GetSessionCookie(ctx *fasthttp.RequestCtx) string { func GetSessionCookie(ctx *fasthttp.RequestCtx) string {
return GetCookie(ctx, SessionCookieName) return cookies.GetCookie(ctx, SessionCookieName)
} }
func DeleteSessionCookie(ctx *fasthttp.RequestCtx) { func DeleteSessionCookie(ctx *fasthttp.RequestCtx) {
DeleteCookie(ctx, SessionCookieName) cookies.DeleteCookie(ctx, SessionCookieName)
} }
func isHTTPS(ctx *fasthttp.RequestCtx) bool {
return ctx.IsTLS() ||
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
}

View File

@ -0,0 +1,77 @@
package cookies
import (
"time"
"github.com/valyala/fasthttp"
)
type CookieOptions struct {
Name string
Value string
Path string
Domain string
Expires time.Time
MaxAge int
Secure bool
HTTPOnly bool
SameSite string
}
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(opts.Name)
cookie.SetValue(opts.Value)
if opts.Path != "" {
cookie.SetPath(opts.Path)
} else {
cookie.SetPath("/")
}
if opts.Domain != "" {
cookie.SetDomain(opts.Domain)
}
if !opts.Expires.IsZero() {
cookie.SetExpire(opts.Expires)
}
if opts.MaxAge > 0 {
cookie.SetMaxAge(opts.MaxAge)
}
cookie.SetSecure(opts.Secure)
cookie.SetHTTPOnly(opts.HTTPOnly)
switch opts.SameSite {
case "strict":
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "lax":
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
case "none":
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
ctx.Response.Header.SetCookie(cookie)
}
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
return string(ctx.Request.Header.Cookie(name))
}
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
SetSecureCookie(ctx, CookieOptions{
Name: name,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HTTPOnly: true,
Secure: true,
SameSite: "lax",
})
}

View File

@ -33,13 +33,9 @@ func Start(port string) error {
} }
defer database.Close() defer database.Close()
// Initialize auth singleton auth.Init("sessions.json") // Initialize auth.Manager
auth.InitializeManager(database.GetDB(), "sessions.json")
// Initialize router
r := router.New() r := router.New()
// Add middleware
r.Use(middleware.Timing()) r.Use(middleware.Timing())
r.Use(middleware.Auth(auth.Manager)) r.Use(middleware.Auth(auth.Manager))
r.Use(middleware.CSRF(auth.Manager)) r.Use(middleware.CSRF(auth.Manager))

10
internal/utils/http.go Normal file
View File

@ -0,0 +1,10 @@
package utils
import "github.com/valyala/fasthttp"
// IsHTTPS tries to determine if the current request context is over HTTPS
func IsHTTPS(ctx *fasthttp.RequestCtx) bool {
return ctx.IsTLS() ||
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
}