auth system first pass
This commit is contained in:
parent
a1e8d49c0e
commit
a49346160b
93
internal/auth/auth.go
Normal file
93
internal/auth/auth.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"dk/internal/database"
|
||||||
|
"dk/internal/password"
|
||||||
|
"dk/internal/users"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID int
|
||||||
|
Username string
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthManager struct {
|
||||||
|
sessionStore *SessionStore
|
||||||
|
db *database.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager {
|
||||||
|
return &AuthManager{
|
||||||
|
sessionStore: NewSessionStore(sessionsFilePath),
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
|
||||||
|
var user *users.User
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Try to find user by username first
|
||||||
|
user, err = users.GetByUsername(am.db, usernameOrEmail)
|
||||||
|
if err != nil {
|
||||||
|
// Try by email if username lookup failed
|
||||||
|
user, err = users.GetByEmail(am.db, usernameOrEmail)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify password
|
||||||
|
isValid, err := password.Verify(user.Password, plainPassword)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !isValid {
|
||||||
|
return nil, ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
return &User{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Email: user.Email,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AuthManager) CreateSession(user *User) *Session {
|
||||||
|
return am.sessionStore.Create(user.ID, user.Username, user.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AuthManager) GetSession(sessionID string) (*Session, bool) {
|
||||||
|
return am.sessionStore.Get(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AuthManager) UpdateSession(sessionID string) bool {
|
||||||
|
return am.sessionStore.Update(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AuthManager) DeleteSession(sessionID string) {
|
||||||
|
am.sessionStore.Delete(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AuthManager) SessionStats() (total, active int) {
|
||||||
|
return am.sessionStore.Stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *AuthManager) Close() error {
|
||||||
|
return am.sessionStore.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidCredentials = &AuthError{"invalid username/email or password"}
|
||||||
|
ErrSessionNotFound = &AuthError{"session not found"}
|
||||||
|
ErrSessionExpired = &AuthError{"session expired"}
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthError struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AuthError) Error() string {
|
||||||
|
return e.Message
|
||||||
|
}
|
103
internal/auth/cookies.go
Normal file
103
internal/auth/cookies.go
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CookieOptions struct {
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
Path string
|
||||||
|
Domain string
|
||||||
|
Expires time.Time
|
||||||
|
MaxAge int
|
||||||
|
Secure bool
|
||||||
|
HTTPOnly bool
|
||||||
|
SameSite string
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
|
||||||
|
cookie := &fasthttp.Cookie{}
|
||||||
|
|
||||||
|
cookie.SetKey(opts.Name)
|
||||||
|
cookie.SetValue(opts.Value)
|
||||||
|
|
||||||
|
if opts.Path != "" {
|
||||||
|
cookie.SetPath(opts.Path)
|
||||||
|
} else {
|
||||||
|
cookie.SetPath("/")
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.Domain != "" {
|
||||||
|
cookie.SetDomain(opts.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.Expires.IsZero() {
|
||||||
|
cookie.SetExpire(opts.Expires)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.MaxAge > 0 {
|
||||||
|
cookie.SetMaxAge(opts.MaxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie.SetSecure(opts.Secure)
|
||||||
|
cookie.SetHTTPOnly(opts.HTTPOnly)
|
||||||
|
|
||||||
|
switch opts.SameSite {
|
||||||
|
case "strict":
|
||||||
|
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
||||||
|
case "lax":
|
||||||
|
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||||
|
case "none":
|
||||||
|
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
|
||||||
|
default:
|
||||||
|
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Response.Header.SetCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
|
||||||
|
return string(ctx.Request.Header.Cookie(name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
|
||||||
|
SetSecureCookie(ctx, CookieOptions{
|
||||||
|
Name: name,
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
Expires: time.Unix(0, 0),
|
||||||
|
MaxAge: -1,
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: "lax",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetSessionCookie(ctx *fasthttp.RequestCtx, sessionID string) {
|
||||||
|
SetSecureCookie(ctx, CookieOptions{
|
||||||
|
Name: SessionCookieName,
|
||||||
|
Value: sessionID,
|
||||||
|
Path: "/",
|
||||||
|
Expires: time.Now().Add(DefaultExpiration),
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: isHTTPS(ctx),
|
||||||
|
SameSite: "lax",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSessionCookie(ctx *fasthttp.RequestCtx) string {
|
||||||
|
return GetCookie(ctx, SessionCookieName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteSessionCookie(ctx *fasthttp.RequestCtx) {
|
||||||
|
DeleteCookie(ctx, SessionCookieName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHTTPS(ctx *fasthttp.RequestCtx) bool {
|
||||||
|
return ctx.IsTLS() ||
|
||||||
|
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
|
||||||
|
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
|
||||||
|
}
|
4
internal/auth/doc.go
Normal file
4
internal/auth/doc.go
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
// Package auth provides authentication and session management functionality.
|
||||||
|
// It includes secure session storage with in-memory caching and JSON persistence,
|
||||||
|
// user authentication against the database, and secure cookie handling.
|
||||||
|
package auth
|
221
internal/auth/session.go
Normal file
221
internal/auth/session.go
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SessionCookieName = "dk_session"
|
||||||
|
DefaultExpiration = 24 * time.Hour
|
||||||
|
SessionIDLength = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
sessions map[string]*Session
|
||||||
|
filePath string
|
||||||
|
saveInterval time.Duration
|
||||||
|
stopChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type persistedData struct {
|
||||||
|
Sessions map[string]*Session `json:"sessions"`
|
||||||
|
SavedAt time.Time `json:"saved_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSessionStore(filePath string) *SessionStore {
|
||||||
|
store := &SessionStore{
|
||||||
|
sessions: make(map[string]*Session),
|
||||||
|
filePath: filePath,
|
||||||
|
saveInterval: 5 * time.Minute,
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
store.loadFromFile()
|
||||||
|
store.startPeriodicSave()
|
||||||
|
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) generateSessionID() string {
|
||||||
|
bytes := make([]byte, SessionIDLength)
|
||||||
|
rand.Read(bytes)
|
||||||
|
return hex.EncodeToString(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Create(userID int, username, email string) *Session {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
session := &Session{
|
||||||
|
ID: s.generateSessionID(),
|
||||||
|
UserID: userID,
|
||||||
|
Username: username,
|
||||||
|
Email: email,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(DefaultExpiration),
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sessions[session.ID] = session
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Get(sessionID string) (*Session, bool) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
session, exists := s.sessions[sessionID]
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(session.ExpiresAt) {
|
||||||
|
delete(s.sessions, sessionID)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Update(sessionID string) bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
session, exists := s.sessions[sessionID]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(session.ExpiresAt) {
|
||||||
|
delete(s.sessions, sessionID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
session.LastSeen = time.Now()
|
||||||
|
session.ExpiresAt = time.Now().Add(DefaultExpiration)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Delete(sessionID string) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
delete(s.sessions, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Cleanup() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for id, session := range s.sessions {
|
||||||
|
if now.After(session.ExpiresAt) {
|
||||||
|
delete(s.sessions, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) loadFromFile() {
|
||||||
|
if s.filePath == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(s.filePath)
|
||||||
|
if err != nil {
|
||||||
|
return // File might not exist yet
|
||||||
|
}
|
||||||
|
|
||||||
|
var persisted persistedData
|
||||||
|
if err := json.Unmarshal(data, &persisted); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for id, session := range persisted.Sessions {
|
||||||
|
if now.Before(session.ExpiresAt) {
|
||||||
|
s.sessions[id] = session
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) saveToFile() error {
|
||||||
|
if s.filePath == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
sessionsCopy := make(map[string]*Session)
|
||||||
|
maps.Copy(sessionsCopy, s.sessions)
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
data := persistedData{
|
||||||
|
Sessions: sessionsCopy,
|
||||||
|
SavedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.MarshalIndent(data, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.WriteFile(s.filePath, jsonData, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) startPeriodicSave() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(s.saveInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
s.Cleanup()
|
||||||
|
s.saveToFile()
|
||||||
|
case <-s.stopChan:
|
||||||
|
s.saveToFile()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Close() error {
|
||||||
|
close(s.stopChan)
|
||||||
|
return s.saveToFile()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) Stats() (total, active int) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
total = len(s.sessions)
|
||||||
|
|
||||||
|
for _, session := range s.sessions {
|
||||||
|
if now.Before(session.ExpiresAt) {
|
||||||
|
active++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -57,6 +57,16 @@ func (db *DB) Close() error {
|
|||||||
return db.pool.Close()
|
return db.pool.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConn gets a connection from the pool - caller must call Put when done
|
||||||
|
func (db *DB) GetConn(ctx context.Context) (*sqlite.Conn, error) {
|
||||||
|
return db.pool.Take(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutConn returns a connection to the pool
|
||||||
|
func (db *DB) PutConn(conn *sqlite.Conn) {
|
||||||
|
db.pool.Put(conn)
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes a SQL statement without returning results
|
// Exec executes a SQL statement without returning results
|
||||||
func (db *DB) Exec(query string, args ...any) error {
|
func (db *DB) Exec(query string, args ...any) error {
|
||||||
conn, err := db.pool.Take(context.Background())
|
conn, err := db.pool.Take(context.Background())
|
||||||
|
116
internal/middleware/auth.go
Normal file
116
internal/middleware/auth.go
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/router"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
UserKey = "user"
|
||||||
|
SessionKey = "session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Auth creates an authentication middleware
|
||||||
|
func Auth(authManager *auth.AuthManager) router.Middleware {
|
||||||
|
return func(next router.Handler) router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
sessionID := auth.GetSessionCookie(ctx)
|
||||||
|
|
||||||
|
if sessionID != "" {
|
||||||
|
if session, exists := authManager.GetSession(sessionID); exists {
|
||||||
|
// Update session activity
|
||||||
|
authManager.UpdateSession(sessionID)
|
||||||
|
|
||||||
|
// Store session and user info in context
|
||||||
|
ctx.SetUserValue(SessionKey, session)
|
||||||
|
ctx.SetUserValue(UserKey, &auth.User{
|
||||||
|
ID: session.UserID,
|
||||||
|
Username: session.Username,
|
||||||
|
Email: session.Email,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Refresh the cookie
|
||||||
|
auth.SetSessionCookie(ctx, sessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx, params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireAuth enforces authentication - redirects to login if not authenticated
|
||||||
|
func RequireAuth(loginPath string) router.Middleware {
|
||||||
|
return func(next router.Handler) router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
if !IsAuthenticated(ctx) {
|
||||||
|
ctx.Redirect(loginPath, fasthttp.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx, params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireGuest enforces no authentication - redirects to dashboard if authenticated
|
||||||
|
func RequireGuest(dashboardPath string) router.Middleware {
|
||||||
|
return func(next router.Handler) router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
if IsAuthenticated(ctx) {
|
||||||
|
ctx.Redirect(dashboardPath, fasthttp.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx, params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAuthenticated checks if the current request has a valid session
|
||||||
|
func IsAuthenticated(ctx router.Ctx) bool {
|
||||||
|
_, exists := ctx.UserValue(UserKey).(*auth.User)
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCurrentUser returns the current authenticated user, or nil if not authenticated
|
||||||
|
func GetCurrentUser(ctx router.Ctx) *auth.User {
|
||||||
|
if user, ok := ctx.UserValue(UserKey).(*auth.User); ok {
|
||||||
|
return user
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCurrentSession returns the current session, or nil if not authenticated
|
||||||
|
func GetCurrentSession(ctx router.Ctx) *auth.Session {
|
||||||
|
if session, ok := ctx.UserValue(SessionKey).(*auth.Session); ok {
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login creates a session and sets the cookie
|
||||||
|
func Login(ctx router.Ctx, authManager *auth.AuthManager, user *auth.User) {
|
||||||
|
session := authManager.CreateSession(user)
|
||||||
|
auth.SetSessionCookie(ctx, session.ID)
|
||||||
|
|
||||||
|
// Set in context for immediate use
|
||||||
|
ctx.SetUserValue(SessionKey, session)
|
||||||
|
ctx.SetUserValue(UserKey, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logout destroys the session and clears the cookie
|
||||||
|
func Logout(ctx router.Ctx, authManager *auth.AuthManager) {
|
||||||
|
sessionID := auth.GetSessionCookie(ctx)
|
||||||
|
if sessionID != "" {
|
||||||
|
authManager.DeleteSession(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.DeleteSessionCookie(ctx)
|
||||||
|
|
||||||
|
// Clear from context
|
||||||
|
ctx.SetUserValue(SessionKey, nil)
|
||||||
|
ctx.SetUserValue(UserKey, nil)
|
||||||
|
}
|
@ -4,8 +4,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/database"
|
||||||
"dk/internal/middleware"
|
"dk/internal/middleware"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
"dk/internal/template"
|
"dk/internal/template"
|
||||||
@ -21,11 +25,23 @@ func Start(port string) error {
|
|||||||
}
|
}
|
||||||
templateCache := template.NewCache(cwd)
|
templateCache := template.NewCache(cwd)
|
||||||
|
|
||||||
|
// Initialize database
|
||||||
|
db, err := database.Open("dk.sqlite")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open database: %w", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Initialize authentication manager
|
||||||
|
authManager := auth.NewAuthManager(db, "sessions.json")
|
||||||
|
defer authManager.Close()
|
||||||
|
|
||||||
// Initialize router
|
// Initialize router
|
||||||
r := router.New()
|
r := router.New()
|
||||||
|
|
||||||
// Add timing middleware
|
// Add middleware
|
||||||
r.Use(middleware.Timing())
|
r.Use(middleware.Timing())
|
||||||
|
r.Use(middleware.Auth(authManager))
|
||||||
|
|
||||||
// Hello world endpoint
|
// Hello world endpoint
|
||||||
r.Get("/", func(ctx router.Ctx, params []string) {
|
r.Get("/", func(ctx router.Ctx, params []string) {
|
||||||
@ -36,13 +52,28 @@ func Start(port string) error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get current user if authenticated
|
||||||
|
currentUser := middleware.GetCurrentUser(ctx)
|
||||||
|
var username string
|
||||||
|
if currentUser != nil {
|
||||||
|
username = currentUser.Username
|
||||||
|
} else {
|
||||||
|
username = "Guest"
|
||||||
|
}
|
||||||
|
|
||||||
|
totalSessions, activeSessions := authManager.SessionStats()
|
||||||
|
|
||||||
data := map[string]any{
|
data := map[string]any{
|
||||||
"title": "Dragon Knight",
|
"title": "Dragon Knight",
|
||||||
"content": "Hello World!",
|
"content": fmt.Sprintf("Hello %s!", username),
|
||||||
"totaltime": middleware.GetRequestTime(ctx),
|
"totaltime": middleware.GetRequestTime(ctx),
|
||||||
"numqueries": "0", // Placeholder for now
|
"numqueries": "0", // Placeholder for now
|
||||||
"version": "1.0.0",
|
"version": "1.0.0",
|
||||||
"build": "dev",
|
"build": "dev",
|
||||||
|
"total_sessions": totalSessions,
|
||||||
|
"active_sessions": activeSessions,
|
||||||
|
"authenticated": currentUser != nil,
|
||||||
|
"username": username,
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpl.WriteTo(ctx, data)
|
tmpl.WriteTo(ctx, data)
|
||||||
@ -78,5 +109,32 @@ func Start(port string) error {
|
|||||||
|
|
||||||
addr := ":" + port
|
addr := ":" + port
|
||||||
log.Printf("Server starting on %s", addr)
|
log.Printf("Server starting on %s", addr)
|
||||||
return fasthttp.ListenAndServe(addr, requestHandler)
|
|
||||||
|
// Setup graceful shutdown
|
||||||
|
server := &fasthttp.Server{
|
||||||
|
Handler: requestHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Channel to listen for interrupt signal
|
||||||
|
c := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Start server in a goroutine
|
||||||
|
go func() {
|
||||||
|
if err := server.ListenAndServe(addr); err != nil {
|
||||||
|
log.Printf("Server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Block until we receive a signal
|
||||||
|
<-c
|
||||||
|
log.Println("Shutting down server...")
|
||||||
|
|
||||||
|
// Shutdown server gracefully
|
||||||
|
if err := server.Shutdown(); err != nil {
|
||||||
|
log.Printf("Server shutdown error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Server stopped")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -422,3 +422,47 @@ func (u *User) SetPosition(x, y int) {
|
|||||||
u.X = x
|
u.X = x
|
||||||
u.Y = y
|
u.Y = y
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetByUsername retrieves a user by username
|
||||||
|
func GetByUsername(db *database.DB, username string) (*User, error) {
|
||||||
|
var user *User
|
||||||
|
|
||||||
|
query := `SELECT ` + userColumns() + ` FROM users WHERE LOWER(username) = LOWER(?) LIMIT 1`
|
||||||
|
|
||||||
|
err := db.Query(query, func(stmt *sqlite.Stmt) error {
|
||||||
|
user = scanUser(stmt, db)
|
||||||
|
return nil
|
||||||
|
}, username)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user == nil {
|
||||||
|
return nil, fmt.Errorf("user not found: %s", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByEmail retrieves a user by email
|
||||||
|
func GetByEmail(db *database.DB, email string) (*User, error) {
|
||||||
|
var user *User
|
||||||
|
|
||||||
|
query := `SELECT ` + userColumns() + ` FROM users WHERE LOWER(email) = LOWER(?) LIMIT 1`
|
||||||
|
|
||||||
|
err := db.Query(query, func(stmt *sqlite.Stmt) error {
|
||||||
|
user = scanUser(stmt, db)
|
||||||
|
return nil
|
||||||
|
}, email)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user == nil {
|
||||||
|
return nil, fmt.Errorf("user not found: %s", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
37
templates/auth/login.html
Normal file
37
templates/auth/login.html
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
{flashhtml}
|
||||||
|
|
||||||
|
<form action="/login" method="post">
|
||||||
|
{csrf}
|
||||||
|
<table width="75%">
|
||||||
|
<tr>
|
||||||
|
<td width="30%">Username:</td>
|
||||||
|
<td><input type="text" size="30" name="username"></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>Password:</td>
|
||||||
|
<td><input type="password" size="30" name="password"></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>Remember me?</td>
|
||||||
|
<td><input type="checkbox" name="rememberme" value="yes"> Yes</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td colspan="2"><input type="submit" name="submit" value="Log In"></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td colspan="2">
|
||||||
|
Checking the "Remember Me" option will store your login information in a cookie
|
||||||
|
so you don't have to enter it next time you get online.
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
|
Want to play? You gotta <a href="/register">register your own character.</a>
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
|
You may also <a href="/change-password">change your password</a>, or
|
||||||
|
<a href="/lost-password">request a new one</a> if you've lost yours.
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</form>
|
60
templates/auth/register.html
Normal file
60
templates/auth/register.html
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
{flashhtml}
|
||||||
|
|
||||||
|
<form action="/register" method="post">
|
||||||
|
{csrf}
|
||||||
|
<table width="80%">
|
||||||
|
<tr>
|
||||||
|
<td width="20%">Username:</td>
|
||||||
|
<td>
|
||||||
|
<input type="text" name="username" size="30" maxlength="30">
|
||||||
|
<br>
|
||||||
|
Usernames must be 30 alphanumeric characters or less.
|
||||||
|
<br><br><br>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>Password:</td>
|
||||||
|
<td><input type="password" name="password1" size="30" maxlength="10"></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>Verify Password:</td>
|
||||||
|
<td>
|
||||||
|
<input type="password" name="password2" size="30" maxlength="10">
|
||||||
|
<br>
|
||||||
|
Passwords must be 10 alphanumeric characters or less.
|
||||||
|
<br><br><br>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>Email Address:</td>
|
||||||
|
<td><input type="email" name="email1" size="30" maxlength="100"></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>Verify Email:</td>
|
||||||
|
<td>
|
||||||
|
<input type="text" name="email2" size="30" maxlength="100">
|
||||||
|
{verifytext}
|
||||||
|
<br><br><br>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td>Character Class:</td>
|
||||||
|
<td>
|
||||||
|
<select name="charclass">
|
||||||
|
<option value="1">{class1name}</option>
|
||||||
|
<option value="2">{class2name}</option>
|
||||||
|
<option value="3">{class3name}</option>
|
||||||
|
</select>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td colspan="2">See <a href="/help">Help</a> for more information about character classes.<br><br></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td colspan="2">
|
||||||
|
<input type="submit" name="submit" value="Submit">
|
||||||
|
<input type="reset" name="reset" value="Reset">
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</form>
|
Loading…
x
Reference in New Issue
Block a user