From b1f436585efc30e5944291084fa546ca9ef27743 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Sat, 9 Aug 2025 18:12:23 -0500 Subject: [PATCH] Simplify auth package --- internal/auth/auth.go | 40 +++++++---------- internal/auth/cookies.go | 86 +++---------------------------------- internal/cookies/cookies.go | 77 +++++++++++++++++++++++++++++++++ internal/server/server.go | 6 +-- internal/utils/http.go | 10 +++++ 5 files changed, 110 insertions(+), 109 deletions(-) create mode 100644 internal/cookies/cookies.go create mode 100644 internal/utils/http.go diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 636897a..0f67f1f 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,7 +1,6 @@ package auth import ( - "dk/internal/database" "dk/internal/password" "dk/internal/users" ) @@ -9,29 +8,28 @@ import ( // Manager is the global singleton instance var Manager *AuthManager +// User is a simplified User struct for auth purposes type User struct { ID int Username string Email string } +// AuthManager is a wrapper for the session store to add +// authentication tools over the store itself type AuthManager struct { - sessionStore *SessionStore - db *database.DB + store *SessionStore } -func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager { - return &AuthManager{ - sessionStore: NewSessionStore(sessionsFilePath), - db: db, +// Init initializes the global auth manager (auth.Manager) +func Init(sessionsFilePath string) { + Manager = &AuthManager{ + store: NewSessionStore(sessionsFilePath), } } -// InitializeManager initializes the global Manager singleton -func InitializeManager(db *database.DB, sessionsFilePath string) { - Manager = NewAuthManager(db, sessionsFilePath) -} - +// Authenticate checks for the usernaname or email, then verifies the plain password +// against the stored hash. func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) { var user *users.User var err error @@ -39,14 +37,12 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*Use // Try to find user by username first user, err = users.GetByUsername(usernameOrEmail) if err != nil { - // Try by email if username lookup failed user, err = users.GetByEmail(usernameOrEmail) if err != nil { return nil, err } } - // Verify password isValid, err := password.Verify(plainPassword, user.Password) if err != nil { return nil, err @@ -63,31 +59,27 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*Use } func (am *AuthManager) CreateSession(user *User) *Session { - return am.sessionStore.Create(user.ID, user.Username, user.Email) + return am.store.Create(user.ID, user.Username, user.Email) } func (am *AuthManager) GetSession(sessionID string) (*Session, bool) { - return am.sessionStore.Get(sessionID) + return am.store.Get(sessionID) } func (am *AuthManager) UpdateSession(sessionID string) bool { - return am.sessionStore.Update(sessionID) + return am.store.Update(sessionID) } func (am *AuthManager) DeleteSession(sessionID string) { - am.sessionStore.Delete(sessionID) + am.store.Delete(sessionID) } func (am *AuthManager) SessionStats() (total, active int) { - return am.sessionStore.Stats() -} - -func (am *AuthManager) DB() *database.DB { - return am.db + return am.store.Stats() } func (am *AuthManager) Close() error { - return am.sessionStore.Close() + return am.store.Close() } var ( diff --git a/internal/auth/cookies.go b/internal/auth/cookies.go index 437b624..519f410 100644 --- a/internal/auth/cookies.go +++ b/internal/auth/cookies.go @@ -1,103 +1,29 @@ package auth import ( + "dk/internal/cookies" + "dk/internal/utils" "time" "github.com/valyala/fasthttp" ) -type CookieOptions struct { - Name string - Value string - Path string - Domain string - Expires time.Time - MaxAge int - Secure bool - HTTPOnly bool - SameSite string -} - -func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) { - cookie := &fasthttp.Cookie{} - - cookie.SetKey(opts.Name) - cookie.SetValue(opts.Value) - - if opts.Path != "" { - cookie.SetPath(opts.Path) - } else { - cookie.SetPath("/") - } - - if opts.Domain != "" { - cookie.SetDomain(opts.Domain) - } - - if !opts.Expires.IsZero() { - cookie.SetExpire(opts.Expires) - } - - if opts.MaxAge > 0 { - cookie.SetMaxAge(opts.MaxAge) - } - - cookie.SetSecure(opts.Secure) - cookie.SetHTTPOnly(opts.HTTPOnly) - - switch opts.SameSite { - case "strict": - cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) - case "lax": - cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) - case "none": - cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) - default: - cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) - } - - ctx.Response.Header.SetCookie(cookie) -} - -func GetCookie(ctx *fasthttp.RequestCtx, name string) string { - return string(ctx.Request.Header.Cookie(name)) -} - -func DeleteCookie(ctx *fasthttp.RequestCtx, name string) { - SetSecureCookie(ctx, CookieOptions{ - Name: name, - Value: "", - Path: "/", - Expires: time.Unix(0, 0), - MaxAge: -1, - HTTPOnly: true, - Secure: true, - SameSite: "lax", - }) -} - func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) { - SetSecureCookie(ctx, CookieOptions{ + cookies.SetSecureCookie(ctx, cookies.CookieOptions{ Name: SessionCookieName, Value: sessionID, Path: "/", Expires: time.Now().Add(DefaultExpiration), HTTPOnly: true, - Secure: isHTTPS(ctx), + Secure: utils.IsHTTPS(ctx), SameSite: "lax", }) } func GetSessionCookie(ctx *fasthttp.RequestCtx) string { - return GetCookie(ctx, SessionCookieName) + return cookies.GetCookie(ctx, SessionCookieName) } func DeleteSessionCookie(ctx *fasthttp.RequestCtx) { - DeleteCookie(ctx, SessionCookieName) + cookies.DeleteCookie(ctx, SessionCookieName) } - -func isHTTPS(ctx *fasthttp.RequestCtx) bool { - return ctx.IsTLS() || - string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" || - string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https" -} \ No newline at end of file diff --git a/internal/cookies/cookies.go b/internal/cookies/cookies.go new file mode 100644 index 0000000..d3866b4 --- /dev/null +++ b/internal/cookies/cookies.go @@ -0,0 +1,77 @@ +package cookies + +import ( + "time" + + "github.com/valyala/fasthttp" +) + +type CookieOptions struct { + Name string + Value string + Path string + Domain string + Expires time.Time + MaxAge int + Secure bool + HTTPOnly bool + SameSite string +} + +func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) { + cookie := &fasthttp.Cookie{} + + cookie.SetKey(opts.Name) + cookie.SetValue(opts.Value) + + if opts.Path != "" { + cookie.SetPath(opts.Path) + } else { + cookie.SetPath("/") + } + + if opts.Domain != "" { + cookie.SetDomain(opts.Domain) + } + + if !opts.Expires.IsZero() { + cookie.SetExpire(opts.Expires) + } + + if opts.MaxAge > 0 { + cookie.SetMaxAge(opts.MaxAge) + } + + cookie.SetSecure(opts.Secure) + cookie.SetHTTPOnly(opts.HTTPOnly) + + switch opts.SameSite { + case "strict": + cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) + case "lax": + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + case "none": + cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) + default: + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + } + + ctx.Response.Header.SetCookie(cookie) +} + +func GetCookie(ctx *fasthttp.RequestCtx, name string) string { + return string(ctx.Request.Header.Cookie(name)) +} + +func DeleteCookie(ctx *fasthttp.RequestCtx, name string) { + SetSecureCookie(ctx, CookieOptions{ + Name: name, + Value: "", + Path: "/", + Expires: time.Unix(0, 0), + MaxAge: -1, + HTTPOnly: true, + Secure: true, + SameSite: "lax", + }) +} diff --git a/internal/server/server.go b/internal/server/server.go index 53c1c52..5404c41 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -33,13 +33,9 @@ func Start(port string) error { } defer database.Close() - // Initialize auth singleton - auth.InitializeManager(database.GetDB(), "sessions.json") + auth.Init("sessions.json") // Initialize auth.Manager - // Initialize router r := router.New() - - // Add middleware r.Use(middleware.Timing()) r.Use(middleware.Auth(auth.Manager)) r.Use(middleware.CSRF(auth.Manager)) diff --git a/internal/utils/http.go b/internal/utils/http.go new file mode 100644 index 0000000..516c737 --- /dev/null +++ b/internal/utils/http.go @@ -0,0 +1,10 @@ +package utils + +import "github.com/valyala/fasthttp" + +// IsHTTPS tries to determine if the current request context is over HTTPS +func IsHTTPS(ctx *fasthttp.RequestCtx) bool { + return ctx.IsTLS() || + string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" || + string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https" +}