housekeeping
This commit is contained in:
parent
4a5f0debf6
commit
c5218c6061
@ -1,43 +0,0 @@
|
|||||||
// Package auth provides authentication functionality.
|
|
||||||
// It handles user authentication against the database and password verification.
|
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"dk/internal/models/users"
|
|
||||||
"dk/internal/password"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Authenticate(usernameOrEmail, plainPassword string) (*users.User, error) {
|
|
||||||
var user *users.User
|
|
||||||
var err error
|
|
||||||
|
|
||||||
user, err = users.ByUsername(usernameOrEmail)
|
|
||||||
if err != nil {
|
|
||||||
user, err = users.ByEmail(usernameOrEmail)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
isValid, err := password.Verify(plainPassword, user.Password)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !isValid {
|
|
||||||
return nil, ErrInvalidCredentials
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInvalidCredentials = &AuthError{"invalid username/email or password"}
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthError struct {
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *AuthError) Error() string {
|
|
||||||
return e.Message
|
|
||||||
}
|
|
@ -1,7 +1,6 @@
|
|||||||
package cookies
|
package cookies
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
@ -76,16 +75,3 @@ func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
|
|||||||
SameSite: "lax",
|
SameSite: "lax",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsHTTPS(ctx *fasthttp.RequestCtx) bool {
|
|
||||||
proto := string(ctx.Request.Header.Peek("X-Forwarded-Proto"))
|
|
||||||
if proto == "https" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.EqualFold(ctx.Request.URI().Scheme(), []byte("https")) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx.IsTLS()
|
|
||||||
}
|
|
||||||
|
@ -1,106 +0,0 @@
|
|||||||
// Package scanner provides fast struct scanning for SQLite results without runtime reflection
|
|
||||||
package scanner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"zombiezen.com/go/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ScanFunc defines how to scan a column into a field
|
|
||||||
type ScanFunc func(stmt *sqlite.Stmt, colIndex int, fieldPtr unsafe.Pointer)
|
|
||||||
|
|
||||||
// Scanner holds pre-compiled scanning information for a struct type
|
|
||||||
type Scanner struct {
|
|
||||||
scanners []ScanFunc
|
|
||||||
offsets []uintptr
|
|
||||||
columns []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Predefined scan functions for common types
|
|
||||||
func scanInt(stmt *sqlite.Stmt, colIndex int, fieldPtr unsafe.Pointer) {
|
|
||||||
*(*int)(fieldPtr) = stmt.ColumnInt(colIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
func scanInt64(stmt *sqlite.Stmt, colIndex int, fieldPtr unsafe.Pointer) {
|
|
||||||
*(*int64)(fieldPtr) = stmt.ColumnInt64(colIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
func scanString(stmt *sqlite.Stmt, colIndex int, fieldPtr unsafe.Pointer) {
|
|
||||||
*(*string)(fieldPtr) = stmt.ColumnText(colIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
func scanFloat64(stmt *sqlite.Stmt, colIndex int, fieldPtr unsafe.Pointer) {
|
|
||||||
*(*float64)(fieldPtr) = stmt.ColumnFloat(colIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
func scanBool(stmt *sqlite.Stmt, colIndex int, fieldPtr unsafe.Pointer) {
|
|
||||||
*(*bool)(fieldPtr) = stmt.ColumnInt(colIndex) != 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a scanner for the given struct type using reflection once at creation time
|
|
||||||
func New[T any]() *Scanner {
|
|
||||||
var zero T
|
|
||||||
typ := reflect.TypeOf(zero)
|
|
||||||
|
|
||||||
var scanners []ScanFunc
|
|
||||||
var offsets []uintptr
|
|
||||||
var columns []string
|
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
|
||||||
field := typ.Field(i)
|
|
||||||
|
|
||||||
// Skip fields without db tag or with "-"
|
|
||||||
dbTag := field.Tag.Get("db")
|
|
||||||
if dbTag == "" || dbTag == "-" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
columns = append(columns, dbTag)
|
|
||||||
offsets = append(offsets, field.Offset)
|
|
||||||
|
|
||||||
// Map field types to scan functions
|
|
||||||
switch field.Type.Kind() {
|
|
||||||
case reflect.Int:
|
|
||||||
scanners = append(scanners, scanInt)
|
|
||||||
case reflect.Int64:
|
|
||||||
scanners = append(scanners, scanInt64)
|
|
||||||
case reflect.String:
|
|
||||||
scanners = append(scanners, scanString)
|
|
||||||
case reflect.Float64:
|
|
||||||
scanners = append(scanners, scanFloat64)
|
|
||||||
case reflect.Bool:
|
|
||||||
scanners = append(scanners, scanBool)
|
|
||||||
default:
|
|
||||||
// Fallback to string for unknown types
|
|
||||||
scanners = append(scanners, scanString)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Scanner{
|
|
||||||
scanners: scanners,
|
|
||||||
offsets: offsets,
|
|
||||||
columns: columns,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Columns returns the comma-separated column list for SQL queries
|
|
||||||
func (s *Scanner) Columns() string {
|
|
||||||
return strings.Join(s.columns, ", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan fills the destination struct with data from the SQLite statement
|
|
||||||
// This method uses no reflection and operates at near-native performance
|
|
||||||
func (s *Scanner) Scan(stmt *sqlite.Stmt, dest any) {
|
|
||||||
// Get pointer to the struct data
|
|
||||||
ptr := (*[2]uintptr)(unsafe.Pointer(&dest))
|
|
||||||
structPtr := unsafe.Pointer(ptr[1])
|
|
||||||
|
|
||||||
// Scan each field using pre-compiled function pointers and offsets
|
|
||||||
for i := 0; i < len(s.scanners); i++ {
|
|
||||||
fieldPtr := unsafe.Add(structPtr, s.offsets[i])
|
|
||||||
s.scanners[i](stmt, i, fieldPtr)
|
|
||||||
}
|
|
||||||
}
|
|
@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"dk/internal/cookies"
|
"dk/internal/cookies"
|
||||||
|
"dk/internal/helpers"
|
||||||
"dk/internal/models/users"
|
"dk/internal/models/users"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
"dk/internal/session"
|
"dk/internal/session"
|
||||||
@ -134,7 +135,7 @@ func setSessionCookie(ctx router.Ctx, sessionID string) {
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Expires: time.Now().Add(24 * time.Hour),
|
Expires: time.Now().Add(24 * time.Hour),
|
||||||
HTTPOnly: true,
|
HTTPOnly: true,
|
||||||
Secure: cookies.IsHTTPS(ctx),
|
Secure: helpers.IsHTTPS(ctx),
|
||||||
SameSite: "lax",
|
SameSite: "lax",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"dk/internal/auth"
|
|
||||||
"dk/internal/csrf"
|
"dk/internal/csrf"
|
||||||
"dk/internal/middleware"
|
"dk/internal/middleware"
|
||||||
"dk/internal/models/users"
|
"dk/internal/models/users"
|
||||||
@ -75,7 +74,7 @@ func processLogin(ctx router.Ctx, _ []string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := auth.Authenticate(email, userPassword)
|
user, err := authenticate(email, userPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
setFlashAndFormData(ctx, "Invalid email or password", map[string]string{"id": email})
|
setFlashAndFormData(ctx, "Invalid email or password", map[string]string{"id": email})
|
||||||
ctx.Redirect("/login", fasthttp.StatusFound)
|
ctx.Redirect("/login", fasthttp.StatusFound)
|
||||||
@ -242,3 +241,26 @@ func setFlashAndFormData(ctx router.Ctx, message string, formData map[string]str
|
|||||||
sess.Set("form_data", formData)
|
sess.Set("form_data", formData)
|
||||||
session.Store(sess)
|
session.Store(sess)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func authenticate(usernameOrEmail, plainPassword string) (*users.User, error) {
|
||||||
|
var user *users.User
|
||||||
|
var err error
|
||||||
|
|
||||||
|
user, err = users.ByUsername(usernameOrEmail)
|
||||||
|
if err != nil {
|
||||||
|
user, err = users.ByEmail(usernameOrEmail)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
isValid, err := password.Verify(plainPassword, user.Password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !isValid {
|
||||||
|
return nil, fmt.Errorf("invalid username/email or password")
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user