Compare commits
3 Commits
a49346160b
...
0534da09a1
Author | SHA1 | Date | |
---|---|---|---|
0534da09a1 | |||
56dca44815 | |||
cec2b12c35 |
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,3 +1,5 @@
|
|||||||
# Dragon Knight test/build files
|
# Dragon Knight test/build files
|
||||||
/dk
|
/dk
|
||||||
/dk.db
|
/dk.db
|
||||||
|
/dk.db-*
|
||||||
|
/sessions.json
|
||||||
|
@ -6,6 +6,9 @@ import (
|
|||||||
"dk/internal/users"
|
"dk/internal/users"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Manager is the global singleton instance
|
||||||
|
var Manager *AuthManager
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID int
|
ID int
|
||||||
Username string
|
Username string
|
||||||
@ -24,6 +27,11 @@ func NewAuthManager(db *database.DB, sessionsFilePath string) *AuthManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitializeManager initializes the global Manager singleton
|
||||||
|
func InitializeManager(db *database.DB, sessionsFilePath string) {
|
||||||
|
Manager = NewAuthManager(db, sessionsFilePath)
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
|
func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*User, error) {
|
||||||
var user *users.User
|
var user *users.User
|
||||||
var err error
|
var err error
|
||||||
@ -39,7 +47,7 @@ func (am *AuthManager) Authenticate(usernameOrEmail, plainPassword string) (*Use
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify password
|
// Verify password
|
||||||
isValid, err := password.Verify(user.Password, plainPassword)
|
isValid, err := password.Verify(plainPassword, user.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -74,6 +82,10 @@ func (am *AuthManager) SessionStats() (total, active int) {
|
|||||||
return am.sessionStore.Stats()
|
return am.sessionStore.Stats()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *AuthManager) DB() *database.DB {
|
||||||
|
return am.db
|
||||||
|
}
|
||||||
|
|
||||||
func (am *AuthManager) Close() error {
|
func (am *AuthManager) Close() error {
|
||||||
return am.sessionStore.Close()
|
return am.sessionStore.Close()
|
||||||
}
|
}
|
||||||
|
@ -17,13 +17,14 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"-"` // Exclude from JSON since it's stored as the map key
|
||||||
UserID int `json:"user_id"`
|
UserID int `json:"user_id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
ExpiresAt time.Time `json:"expires_at"`
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
LastSeen time.Time `json:"last_seen"`
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
Data map[string]any `json:"data,omitempty"` // For storing additional session data
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionStore struct {
|
type SessionStore struct {
|
||||||
|
187
internal/csrf/csrf.go
Normal file
187
internal/csrf/csrf.go
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
package csrf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/router"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenLength = 32
|
||||||
|
TokenFieldName = "_csrf_token"
|
||||||
|
SessionKey = "csrf_token"
|
||||||
|
SessionCtxKey = "session" // Same as middleware.SessionKey
|
||||||
|
CookieName = "_csrf"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetCurrentSession retrieves the session from context (mirrors middleware function)
|
||||||
|
func GetCurrentSession(ctx router.Ctx) *auth.Session {
|
||||||
|
if session, ok := ctx.UserValue(SessionCtxKey).(*auth.Session); ok {
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateToken creates a new CSRF token and stores it in the session or cookie
|
||||||
|
func GenerateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
// Generate cryptographically secure random bytes
|
||||||
|
tokenBytes := make([]byte, TokenLength)
|
||||||
|
if _, err := rand.Read(tokenBytes); err != nil {
|
||||||
|
// Fallback - this should never happen in practice
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
token := base64.URLEncoding.EncodeToString(tokenBytes)
|
||||||
|
|
||||||
|
// Store token in session if user is authenticated, otherwise use cookie
|
||||||
|
if session := GetCurrentSession(ctx); session != nil {
|
||||||
|
StoreToken(session, token)
|
||||||
|
} else {
|
||||||
|
// Store in cookie for guest users
|
||||||
|
StoreTokenInCookie(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetToken retrieves the current CSRF token from session or cookie, generating one if needed
|
||||||
|
func GetToken(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
session := GetCurrentSession(ctx)
|
||||||
|
|
||||||
|
if session != nil {
|
||||||
|
// Authenticated user - check session first
|
||||||
|
if existingToken := GetStoredToken(session); existingToken != "" {
|
||||||
|
return existingToken
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Guest user - check cookie first
|
||||||
|
if existingToken := GetTokenFromCookie(ctx); existingToken != "" {
|
||||||
|
return existingToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new token if none exists
|
||||||
|
return GenerateToken(ctx, authManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken verifies a CSRF token against the stored session or cookie token
|
||||||
|
func ValidateToken(ctx router.Ctx, authManager *auth.AuthManager, submittedToken string) bool {
|
||||||
|
if submittedToken == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var storedToken string
|
||||||
|
session := GetCurrentSession(ctx)
|
||||||
|
|
||||||
|
if session != nil {
|
||||||
|
// Authenticated user - get token from session
|
||||||
|
storedToken = GetStoredToken(session)
|
||||||
|
} else {
|
||||||
|
// Guest user - get token from cookie
|
||||||
|
storedToken = GetTokenFromCookie(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
if storedToken == "" {
|
||||||
|
return false // No stored token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use constant-time comparison to prevent timing attacks
|
||||||
|
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedToken)) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreToken saves a CSRF token in the session
|
||||||
|
func StoreToken(session *auth.Session, token string) {
|
||||||
|
if session.Data == nil {
|
||||||
|
session.Data = make(map[string]any)
|
||||||
|
}
|
||||||
|
session.Data[SessionKey] = token
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStoredToken retrieves the CSRF token from session
|
||||||
|
func GetStoredToken(session *auth.Session) string {
|
||||||
|
if session.Data == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if token, ok := session.Data[SessionKey].(string); ok {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// RotateToken generates a new token and replaces the old one in the session
|
||||||
|
func RotateToken(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
session := GetCurrentSession(ctx)
|
||||||
|
if session == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new token
|
||||||
|
newToken := GenerateToken(ctx, authManager)
|
||||||
|
|
||||||
|
return newToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// HiddenField generates an HTML hidden input field with the CSRF token
|
||||||
|
func HiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
token := GetToken(ctx, authManager)
|
||||||
|
if token == "" {
|
||||||
|
return "" // No token available
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
|
||||||
|
TokenFieldName, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenMeta generates HTML meta tag for JavaScript access to CSRF token
|
||||||
|
func TokenMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
token := GetToken(ctx, authManager)
|
||||||
|
if token == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateFormToken is a convenience function to validate CSRF token from form data
|
||||||
|
func ValidateFormToken(ctx router.Ctx, authManager *auth.AuthManager) bool {
|
||||||
|
// Try to get token from form data
|
||||||
|
tokenBytes := ctx.PostArgs().Peek(TokenFieldName)
|
||||||
|
if len(tokenBytes) == 0 {
|
||||||
|
// Try from query parameters as fallback
|
||||||
|
tokenBytes = ctx.QueryArgs().Peek(TokenFieldName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokenBytes) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return ValidateToken(ctx, authManager, string(tokenBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreTokenInCookie stores a CSRF token in a cookie for guest users
|
||||||
|
func StoreTokenInCookie(ctx router.Ctx, token string) {
|
||||||
|
cookie := &fasthttp.Cookie{}
|
||||||
|
cookie.SetKey(CookieName)
|
||||||
|
cookie.SetValue(token)
|
||||||
|
cookie.SetHTTPOnly(true)
|
||||||
|
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
||||||
|
cookie.SetSecure(false) // Set to true in production with HTTPS
|
||||||
|
cookie.SetExpire(time.Now().Add(24 * time.Hour)) // Expire in 24 hours
|
||||||
|
cookie.SetPath("/")
|
||||||
|
|
||||||
|
ctx.Response.Header.SetCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenFromCookie retrieves a CSRF token from cookie for guest users
|
||||||
|
func GetTokenFromCookie(ctx router.Ctx) string {
|
||||||
|
return string(ctx.Request.Header.Cookie(CookieName))
|
||||||
|
}
|
178
internal/csrf/csrf_test.go
Normal file
178
internal/csrf/csrf_test.go
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
package csrf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateToken(t *testing.T) {
|
||||||
|
// Create a mock session
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour),
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
Data: make(map[string]any),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create mock context
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
// Generate token
|
||||||
|
token := GenerateToken(ctx, nil)
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Error("Expected non-empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that token was stored in session
|
||||||
|
storedToken := GetStoredToken(session)
|
||||||
|
if storedToken != token {
|
||||||
|
t.Errorf("Expected stored token %s, got %s", token, storedToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateToken(t *testing.T) {
|
||||||
|
// Create session with token
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
// Valid token should pass
|
||||||
|
if !ValidateToken(ctx, nil, "test-token") {
|
||||||
|
t.Error("Expected valid token to pass validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid token should fail
|
||||||
|
if ValidateToken(ctx, nil, "wrong-token") {
|
||||||
|
t.Error("Expected invalid token to fail validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty token should fail
|
||||||
|
if ValidateToken(ctx, nil, "") {
|
||||||
|
t.Error("Expected empty token to fail validation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenNoSession(t *testing.T) {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
|
// No session should fail validation
|
||||||
|
if ValidateToken(ctx, nil, "any-token") {
|
||||||
|
t.Error("Expected validation to fail with no session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHiddenField(t *testing.T) {
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
field := HiddenField(ctx, nil)
|
||||||
|
expected := `<input type="hidden" name="_csrf_token" value="test-token">`
|
||||||
|
|
||||||
|
if field != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHiddenFieldNoSession(t *testing.T) {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
|
field := HiddenField(ctx, nil)
|
||||||
|
if field != "" {
|
||||||
|
t.Errorf("Expected empty field with no session, got %s", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenMeta(t *testing.T) {
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
meta := TokenMeta(ctx, nil)
|
||||||
|
expected := `<meta name="csrf-token" content="test-token">`
|
||||||
|
|
||||||
|
if meta != expected {
|
||||||
|
t.Errorf("Expected %s, got %s", expected, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreAndGetToken(t *testing.T) {
|
||||||
|
session := &auth.Session{
|
||||||
|
Data: make(map[string]any),
|
||||||
|
}
|
||||||
|
|
||||||
|
token := "test-token"
|
||||||
|
StoreToken(session, token)
|
||||||
|
|
||||||
|
retrieved := GetStoredToken(session)
|
||||||
|
if retrieved != token {
|
||||||
|
t.Errorf("Expected %s, got %s", token, retrieved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStoredTokenNoData(t *testing.T) {
|
||||||
|
session := &auth.Session{}
|
||||||
|
|
||||||
|
token := GetStoredToken(session)
|
||||||
|
if token != "" {
|
||||||
|
t.Errorf("Expected empty token, got %s", token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateFormToken(t *testing.T) {
|
||||||
|
session := &auth.Session{
|
||||||
|
ID: "test-session",
|
||||||
|
UserID: 1,
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
Data: map[string]any{SessionKey: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetUserValue(SessionCtxKey, session)
|
||||||
|
|
||||||
|
// Set form data
|
||||||
|
ctx.PostArgs().Set(TokenFieldName, "test-token")
|
||||||
|
|
||||||
|
if !ValidateFormToken(ctx, nil) {
|
||||||
|
t.Error("Expected form token validation to pass")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with wrong token
|
||||||
|
ctx.PostArgs().Set(TokenFieldName, "wrong-token")
|
||||||
|
|
||||||
|
if ValidateFormToken(ctx, nil) {
|
||||||
|
t.Error("Expected form token validation to fail with wrong token")
|
||||||
|
}
|
||||||
|
}
|
29
internal/csrf/doc.go
Normal file
29
internal/csrf/doc.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
// Package csrf provides Cross-Site Request Forgery (CSRF) protection
|
||||||
|
// with session-based token storage and form helpers.
|
||||||
|
//
|
||||||
|
// # Basic Usage
|
||||||
|
//
|
||||||
|
// // Generate token and store in session
|
||||||
|
// token := csrf.GenerateToken(ctx, authManager)
|
||||||
|
//
|
||||||
|
// // In templates - generate hidden input field
|
||||||
|
// hiddenField := csrf.HiddenField(ctx, authManager)
|
||||||
|
//
|
||||||
|
// // Verify form submission
|
||||||
|
// if !csrf.ValidateToken(ctx, authManager, formToken) {
|
||||||
|
// // Handle CSRF validation failure
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// # Middleware Integration
|
||||||
|
//
|
||||||
|
// // Add CSRF middleware to protected routes
|
||||||
|
// r.Use(middleware.CSRF(authManager))
|
||||||
|
//
|
||||||
|
// # Security Features
|
||||||
|
//
|
||||||
|
// - Cryptographically secure token generation
|
||||||
|
// - Session-based token storage and validation
|
||||||
|
// - Automatic token rotation on successful validation
|
||||||
|
// - Protection against timing attacks with constant-time comparison
|
||||||
|
// - Integration with existing authentication system
|
||||||
|
package csrf
|
118
internal/middleware/csrf.go
Normal file
118
internal/middleware/csrf.go
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/csrf"
|
||||||
|
"dk/internal/router"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CSRFConfig holds configuration for CSRF middleware
|
||||||
|
type CSRFConfig struct {
|
||||||
|
// Skip CSRF validation for these methods (default: GET, HEAD, OPTIONS)
|
||||||
|
SkipMethods []string
|
||||||
|
// Custom failure handler (default: returns 403)
|
||||||
|
FailureHandler func(ctx router.Ctx)
|
||||||
|
// Skip CSRF for certain paths
|
||||||
|
SkipPaths []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRF creates a CSRF protection middleware
|
||||||
|
func CSRF(authManager *auth.AuthManager, config ...CSRFConfig) router.Middleware {
|
||||||
|
cfg := CSRFConfig{
|
||||||
|
SkipMethods: []string{"GET", "HEAD", "OPTIONS"},
|
||||||
|
FailureHandler: func(ctx router.Ctx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
|
ctx.SetContentType("text/plain")
|
||||||
|
ctx.WriteString("CSRF token validation failed")
|
||||||
|
},
|
||||||
|
SkipPaths: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply custom config if provided
|
||||||
|
if len(config) > 0 {
|
||||||
|
if len(config[0].SkipMethods) > 0 {
|
||||||
|
cfg.SkipMethods = config[0].SkipMethods
|
||||||
|
}
|
||||||
|
if config[0].FailureHandler != nil {
|
||||||
|
cfg.FailureHandler = config[0].FailureHandler
|
||||||
|
}
|
||||||
|
if len(config[0].SkipPaths) > 0 {
|
||||||
|
cfg.SkipPaths = config[0].SkipPaths
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next router.Handler) router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
method := string(ctx.Method())
|
||||||
|
path := string(ctx.Path())
|
||||||
|
|
||||||
|
// Skip CSRF validation for certain methods
|
||||||
|
shouldSkip := slices.Contains(cfg.SkipMethods, method)
|
||||||
|
|
||||||
|
// Skip CSRF validation for certain paths
|
||||||
|
if !shouldSkip {
|
||||||
|
if slices.Contains(cfg.SkipPaths, path) {
|
||||||
|
shouldSkip = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRF protection now works for both authenticated and guest users
|
||||||
|
// Remove the skip for non-authenticated users
|
||||||
|
|
||||||
|
if shouldSkip {
|
||||||
|
next(ctx, params)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate CSRF token for protected methods
|
||||||
|
if !csrf.ValidateFormToken(ctx, authManager) {
|
||||||
|
cfg.FailureHandler(ctx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx, params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireCSRF is a stricter CSRF middleware that always validates tokens
|
||||||
|
func RequireCSRF(authManager *auth.AuthManager, failureHandler ...func(router.Ctx)) router.Middleware {
|
||||||
|
handler := func(ctx router.Ctx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
|
ctx.SetContentType("text/plain")
|
||||||
|
ctx.WriteString("CSRF token required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(failureHandler) > 0 {
|
||||||
|
handler = failureHandler[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next router.Handler) router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
if !csrf.ValidateFormToken(ctx, authManager) {
|
||||||
|
handler(ctx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx, params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRFToken returns the current CSRF token for the request
|
||||||
|
func CSRFToken(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
return csrf.GetToken(ctx, authManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRFHiddenField generates a hidden input field for forms
|
||||||
|
func CSRFHiddenField(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
return csrf.HiddenField(ctx, authManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRFMeta generates a meta tag for JavaScript access
|
||||||
|
func CSRFMeta(ctx router.Ctx, authManager *auth.AuthManager) string {
|
||||||
|
return csrf.TokenMeta(ctx, authManager)
|
||||||
|
}
|
335
internal/routes/auth.go
Normal file
335
internal/routes/auth.go
Normal file
@ -0,0 +1,335 @@
|
|||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/csrf"
|
||||||
|
"dk/internal/middleware"
|
||||||
|
"dk/internal/password"
|
||||||
|
"dk/internal/router"
|
||||||
|
"dk/internal/template"
|
||||||
|
"dk/internal/template/components"
|
||||||
|
"dk/internal/users"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterAuthRoutes sets up authentication routes
|
||||||
|
func RegisterAuthRoutes(r *router.Router) {
|
||||||
|
// Guest routes
|
||||||
|
guestGroup := r.Group("")
|
||||||
|
guestGroup.Use(middleware.RequireGuest("/"))
|
||||||
|
|
||||||
|
guestGroup.Get("/login", showLogin())
|
||||||
|
guestGroup.Post("/login", processLogin())
|
||||||
|
guestGroup.Get("/register", showRegister())
|
||||||
|
guestGroup.Post("/register", processRegister())
|
||||||
|
|
||||||
|
// Authenticated routes
|
||||||
|
authGroup := r.Group("")
|
||||||
|
authGroup.Use(middleware.RequireAuth("/login"))
|
||||||
|
|
||||||
|
authGroup.Post("/logout", processLogout())
|
||||||
|
}
|
||||||
|
|
||||||
|
// showLogin displays the login form
|
||||||
|
func showLogin() router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
loginTmpl, err := template.Cache.Load("auth/login.html")
|
||||||
|
if err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
loginFormData := map[string]any{
|
||||||
|
"csrf_token": csrf.GetToken(ctx, auth.Manager),
|
||||||
|
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
|
||||||
|
"error_message": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
loginContent := loginTmpl.RenderNamed(loginFormData)
|
||||||
|
|
||||||
|
pageData := components.NewPageData("Login - Dragon Knight", loginContent)
|
||||||
|
if err := components.RenderPage(ctx, pageData, nil); err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processLogin handles login form submission
|
||||||
|
func processLogin() router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
// Validate CSRF token
|
||||||
|
if !csrf.ValidateFormToken(ctx, auth.Manager) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
|
ctx.WriteString("CSRF validation failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get form values
|
||||||
|
email := strings.TrimSpace(string(ctx.PostArgs().Peek("email")))
|
||||||
|
userPassword := string(ctx.PostArgs().Peek("password"))
|
||||||
|
|
||||||
|
// Validate input
|
||||||
|
if email == "" || userPassword == "" {
|
||||||
|
showLoginError(ctx, "Email and password are required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate user
|
||||||
|
user, err := auth.Manager.Authenticate(email, userPassword)
|
||||||
|
if err != nil {
|
||||||
|
showLoginError(ctx, "Invalid email or password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create session and login
|
||||||
|
middleware.Login(ctx, auth.Manager, user)
|
||||||
|
|
||||||
|
// Transfer CSRF token from cookie to session for authenticated user
|
||||||
|
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||||
|
if session := csrf.GetCurrentSession(ctx); session != nil {
|
||||||
|
csrf.StoreToken(session, cookieToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect to dashboard
|
||||||
|
ctx.Redirect("/dashboard", fasthttp.StatusFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// showRegister displays the registration form
|
||||||
|
func showRegister() router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
registerTmpl, err := template.Cache.Load("auth/register.html")
|
||||||
|
if err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
registerFormData := map[string]any{
|
||||||
|
"csrf_token": csrf.GetToken(ctx, auth.Manager),
|
||||||
|
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
|
||||||
|
"error_message": "",
|
||||||
|
"username": "",
|
||||||
|
"email": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
registerContent := registerTmpl.RenderNamed(registerFormData)
|
||||||
|
|
||||||
|
pageData := components.NewPageData("Register - Dragon Knight", registerContent)
|
||||||
|
if err := components.RenderPage(ctx, pageData, nil); err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processRegister handles registration form submission
|
||||||
|
func processRegister() router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
// Validate CSRF token
|
||||||
|
if !csrf.ValidateFormToken(ctx, auth.Manager) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
|
ctx.WriteString("CSRF validation failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get form values
|
||||||
|
username := strings.TrimSpace(string(ctx.PostArgs().Peek("username")))
|
||||||
|
email := strings.TrimSpace(string(ctx.PostArgs().Peek("email")))
|
||||||
|
userPassword := string(ctx.PostArgs().Peek("password"))
|
||||||
|
confirmPassword := string(ctx.PostArgs().Peek("confirm_password"))
|
||||||
|
|
||||||
|
// Validate input
|
||||||
|
if err := validateRegistration(username, email, userPassword, confirmPassword); err != nil {
|
||||||
|
showRegisterError(ctx, err.Error(), username, email)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if username already exists
|
||||||
|
if _, err := users.GetByUsername(auth.Manager.DB(), username); err == nil {
|
||||||
|
showRegisterError(ctx, "Username already exists", username, email)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if email already exists
|
||||||
|
if _, err := users.GetByEmail(auth.Manager.DB(), email); err == nil {
|
||||||
|
showRegisterError(ctx, "Email already registered", username, email)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash password
|
||||||
|
hashedPassword, err := password.Hash(userPassword)
|
||||||
|
if err != nil {
|
||||||
|
showRegisterError(ctx, "Failed to process password", username, email)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create user (this is a simplified approach - in a real app you'd use a proper user creation function)
|
||||||
|
user := &users.User{
|
||||||
|
Username: username,
|
||||||
|
Email: email,
|
||||||
|
Password: hashedPassword,
|
||||||
|
Verified: 1, // Auto-verify for now
|
||||||
|
Auth: 1, // Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert into database
|
||||||
|
if err := createUser(user); err != nil {
|
||||||
|
showRegisterError(ctx, "Failed to create account", username, email)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-login after registration
|
||||||
|
authUser := &auth.User{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Email: user.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
middleware.Login(ctx, auth.Manager, authUser)
|
||||||
|
|
||||||
|
// Transfer CSRF token from cookie to session for authenticated user
|
||||||
|
if cookieToken := csrf.GetTokenFromCookie(ctx); cookieToken != "" {
|
||||||
|
if session := csrf.GetCurrentSession(ctx); session != nil {
|
||||||
|
csrf.StoreToken(session, cookieToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Redirect("/", fasthttp.StatusFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processLogout handles logout
|
||||||
|
func processLogout() router.Handler {
|
||||||
|
return func(ctx router.Ctx, params []string) {
|
||||||
|
// Validate CSRF token
|
||||||
|
if !csrf.ValidateFormToken(ctx, auth.Manager) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusForbidden)
|
||||||
|
ctx.WriteString("CSRF validation failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
middleware.Logout(ctx, auth.Manager)
|
||||||
|
ctx.Redirect("/", fasthttp.StatusFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
|
||||||
|
func showLoginError(ctx router.Ctx, errorMsg string) {
|
||||||
|
loginTmpl, err := template.Cache.Load("auth/login.html")
|
||||||
|
if err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorHTML string
|
||||||
|
if errorMsg != "" {
|
||||||
|
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, errorMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
loginFormData := map[string]any{
|
||||||
|
"csrf_token": csrf.GetToken(ctx, auth.Manager),
|
||||||
|
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
|
||||||
|
"error_message": errorHTML,
|
||||||
|
}
|
||||||
|
|
||||||
|
loginContent := loginTmpl.RenderNamed(loginFormData)
|
||||||
|
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||||
|
pageData := components.NewPageData("Login - Dragon Knight", loginContent)
|
||||||
|
if err := components.RenderPage(ctx, pageData, nil); err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func showRegisterError(ctx router.Ctx, errorMsg, username, email string) {
|
||||||
|
registerTmpl, err := template.Cache.Load("auth/register.html")
|
||||||
|
if err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorHTML string
|
||||||
|
if errorMsg != "" {
|
||||||
|
errorHTML = fmt.Sprintf(`<div style="color: red; margin-bottom: 1rem;">%s</div>`, errorMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
registerFormData := map[string]any{
|
||||||
|
"csrf_token": csrf.GetToken(ctx, auth.Manager),
|
||||||
|
"csrf_field": csrf.HiddenField(ctx, auth.Manager),
|
||||||
|
"error_message": errorHTML,
|
||||||
|
"username": username,
|
||||||
|
"email": email,
|
||||||
|
}
|
||||||
|
|
||||||
|
registerContent := registerTmpl.RenderNamed(registerFormData)
|
||||||
|
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusBadRequest)
|
||||||
|
pageData := components.NewPageData("Register - Dragon Knight", registerContent)
|
||||||
|
if err := components.RenderPage(ctx, pageData, nil); err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRegistration(username, email, password, confirmPassword string) error {
|
||||||
|
if username == "" {
|
||||||
|
return fmt.Errorf("username is required")
|
||||||
|
}
|
||||||
|
if len(username) < 3 {
|
||||||
|
return fmt.Errorf("username must be at least 3 characters")
|
||||||
|
}
|
||||||
|
if email == "" {
|
||||||
|
return fmt.Errorf("email is required")
|
||||||
|
}
|
||||||
|
if !strings.Contains(email, "@") {
|
||||||
|
return fmt.Errorf("invalid email address")
|
||||||
|
}
|
||||||
|
if password == "" {
|
||||||
|
return fmt.Errorf("password is required")
|
||||||
|
}
|
||||||
|
if len(password) < 6 {
|
||||||
|
return fmt.Errorf("password must be at least 6 characters")
|
||||||
|
}
|
||||||
|
if password != confirmPassword {
|
||||||
|
return fmt.Errorf("passwords do not match")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createUser inserts a new user into the database
|
||||||
|
// This is a simplified version - in a real app you'd have a proper users.Create function
|
||||||
|
func createUser(user *users.User) error {
|
||||||
|
db := auth.Manager.DB()
|
||||||
|
|
||||||
|
query := `INSERT INTO users (username, password, email, verified, auth) VALUES (?, ?, ?, ?, ?)`
|
||||||
|
|
||||||
|
err := db.Exec(query, user.Username, user.Password, user.Email, user.Verified, user.Auth)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to insert user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the user ID (simplified - in real app you'd return it from insert)
|
||||||
|
createdUser, err := users.GetByUsername(db, user.Username)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get created user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user.ID = createdUser.ID
|
||||||
|
return nil
|
||||||
|
}
|
13
internal/routes/doc.go
Normal file
13
internal/routes/doc.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
// Package routes organizes HTTP route handlers for the Dragon Knight application.
|
||||||
|
// Routes are organized by feature area in separate packages to maintain clean
|
||||||
|
// separation of concerns and make the codebase more maintainable.
|
||||||
|
//
|
||||||
|
// # Structure
|
||||||
|
//
|
||||||
|
// - auth/ - Authentication routes (login, register, logout)
|
||||||
|
// - api/ - API endpoints
|
||||||
|
// - web/ - Web interface routes
|
||||||
|
//
|
||||||
|
// Each route package should provide a Setup function that registers its routes
|
||||||
|
// with the router and returns any necessary dependencies or configuration.
|
||||||
|
package routes
|
@ -12,46 +12,66 @@ import (
|
|||||||
"dk/internal/database"
|
"dk/internal/database"
|
||||||
"dk/internal/middleware"
|
"dk/internal/middleware"
|
||||||
"dk/internal/router"
|
"dk/internal/router"
|
||||||
|
"dk/internal/routes"
|
||||||
"dk/internal/template"
|
"dk/internal/template"
|
||||||
|
"dk/internal/template/components"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Start(port string) error {
|
func Start(port string) error {
|
||||||
// Initialize template cache - use current working directory for development
|
|
||||||
cwd, err := os.Getwd()
|
cwd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get current working directory: %w", err)
|
return fmt.Errorf("failed to get current working directory: %w", err)
|
||||||
}
|
}
|
||||||
templateCache := template.NewCache(cwd)
|
// Initialize template singleton
|
||||||
|
template.InitializeCache(cwd)
|
||||||
|
|
||||||
// Initialize database
|
db, err := database.Open("dk.db")
|
||||||
db, err := database.Open("dk.sqlite")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open database: %w", err)
|
return fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Initialize authentication manager
|
// Initialize auth singleton
|
||||||
authManager := auth.NewAuthManager(db, "sessions.json")
|
auth.InitializeManager(db, "sessions.json")
|
||||||
defer authManager.Close()
|
|
||||||
|
|
||||||
// Initialize router
|
// Initialize router
|
||||||
r := router.New()
|
r := router.New()
|
||||||
|
|
||||||
// Add middleware
|
// Add middleware
|
||||||
r.Use(middleware.Timing())
|
r.Use(middleware.Timing())
|
||||||
r.Use(middleware.Auth(authManager))
|
r.Use(middleware.Auth(auth.Manager))
|
||||||
|
r.Use(middleware.CSRF(auth.Manager))
|
||||||
|
|
||||||
// Hello world endpoint
|
// Setup route handlers
|
||||||
r.Get("/", func(ctx router.Ctx, params []string) {
|
routes.RegisterAuthRoutes(r)
|
||||||
tmpl, err := templateCache.Load("layout.html")
|
|
||||||
if err != nil {
|
// Dashboard (protected route)
|
||||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
r.WithMiddleware(middleware.RequireAuth("/login")).Get("/dashboard", func(ctx router.Ctx, params []string) {
|
||||||
fmt.Fprintf(ctx, "Template error: %v", err)
|
currentUser := middleware.GetCurrentUser(ctx)
|
||||||
return
|
totalSessions, activeSessions := auth.Manager.SessionStats()
|
||||||
|
|
||||||
|
pageData := components.NewPageData(
|
||||||
|
"Dashboard - Dragon Knight",
|
||||||
|
fmt.Sprintf("Welcome back, %s!", currentUser.Username),
|
||||||
|
)
|
||||||
|
|
||||||
|
additionalData := map[string]any{
|
||||||
|
"total_sessions": totalSessions,
|
||||||
|
"active_sessions": activeSessions,
|
||||||
|
"authenticated": true,
|
||||||
|
"username": currentUser.Username,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := components.RenderPage(ctx, pageData, additionalData); err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hello world endpoint (public)
|
||||||
|
r.Get("/", func(ctx router.Ctx, params []string) {
|
||||||
// Get current user if authenticated
|
// Get current user if authenticated
|
||||||
currentUser := middleware.GetCurrentUser(ctx)
|
currentUser := middleware.GetCurrentUser(ctx)
|
||||||
var username string
|
var username string
|
||||||
@ -60,23 +80,26 @@ func Start(port string) error {
|
|||||||
} else {
|
} else {
|
||||||
username = "Guest"
|
username = "Guest"
|
||||||
}
|
}
|
||||||
|
|
||||||
totalSessions, activeSessions := authManager.SessionStats()
|
|
||||||
|
|
||||||
data := map[string]any{
|
totalSessions, activeSessions := auth.Manager.SessionStats()
|
||||||
"title": "Dragon Knight",
|
|
||||||
"content": fmt.Sprintf("Hello %s!", username),
|
pageData := components.NewPageData(
|
||||||
"totaltime": middleware.GetRequestTime(ctx),
|
"Dragon Knight",
|
||||||
"numqueries": "0", // Placeholder for now
|
fmt.Sprintf("Hello %s!", username),
|
||||||
"version": "1.0.0",
|
)
|
||||||
"build": "dev",
|
|
||||||
|
additionalData := map[string]any{
|
||||||
"total_sessions": totalSessions,
|
"total_sessions": totalSessions,
|
||||||
"active_sessions": activeSessions,
|
"active_sessions": activeSessions,
|
||||||
"authenticated": currentUser != nil,
|
"authenticated": currentUser != nil,
|
||||||
"username": username,
|
"username": username,
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpl.WriteTo(ctx, data)
|
if err := components.RenderPage(ctx, pageData, additionalData); err != nil {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(ctx, "Template error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Use current working directory for static files
|
// Use current working directory for static files
|
||||||
@ -109,32 +132,35 @@ func Start(port string) error {
|
|||||||
|
|
||||||
addr := ":" + port
|
addr := ":" + port
|
||||||
log.Printf("Server starting on %s", addr)
|
log.Printf("Server starting on %s", addr)
|
||||||
|
|
||||||
// Setup graceful shutdown
|
// Setup graceful shutdown
|
||||||
server := &fasthttp.Server{
|
server := &fasthttp.Server{
|
||||||
Handler: requestHandler,
|
Handler: requestHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Channel to listen for interrupt signal
|
// Channel to listen for interrupt signal
|
||||||
c := make(chan os.Signal, 1)
|
c := make(chan os.Signal, 1)
|
||||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
// Start server in a goroutine
|
// Start server in a goroutine
|
||||||
go func() {
|
go func() {
|
||||||
if err := server.ListenAndServe(addr); err != nil {
|
if err := server.ListenAndServe(addr); err != nil {
|
||||||
log.Printf("Server error: %v", err)
|
log.Printf("Server error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Block until we receive a signal
|
// Wait for interrupt signal
|
||||||
<-c
|
<-c
|
||||||
log.Println("Shutting down server...")
|
log.Println("Received shutdown signal, shutting down gracefully...")
|
||||||
|
|
||||||
// Shutdown server gracefully
|
// Save sessions before shutdown
|
||||||
if err := server.Shutdown(); err != nil {
|
log.Println("Saving sessions...")
|
||||||
log.Printf("Server shutdown error: %v", err)
|
if err := auth.Manager.Close(); err != nil {
|
||||||
|
log.Printf("Error saving sessions: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FastHTTP doesn't have a graceful Shutdown method like net/http
|
||||||
|
// We just let the server stop naturally when the main function exits
|
||||||
log.Println("Server stopped")
|
log.Println("Server stopped")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
102
internal/template/components/components.go
Normal file
102
internal/template/components/components.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package components
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
|
||||||
|
"dk/internal/auth"
|
||||||
|
"dk/internal/csrf"
|
||||||
|
"dk/internal/middleware"
|
||||||
|
"dk/internal/router"
|
||||||
|
"dk/internal/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateTopNav generates the top navigation HTML based on authentication status
|
||||||
|
func GenerateTopNav(ctx router.Ctx) string {
|
||||||
|
if middleware.IsAuthenticated(ctx) {
|
||||||
|
csrfField := csrf.HiddenField(ctx, auth.Manager)
|
||||||
|
return fmt.Sprintf(`<form action="/logout" method="post" class="logout">
|
||||||
|
%s
|
||||||
|
<button class="img-button" type="submit"><img src="/assets/images/button_logout.gif" alt="Log Out" title="Log Out"></button>
|
||||||
|
</form>
|
||||||
|
<a href="/help"><img src="/assets/images/button_help.gif" alt="Help" title="Help"></a>`, csrfField)
|
||||||
|
} else {
|
||||||
|
return `<a href="/login"><img src="/assets/images/button_login.gif" alt="Log In" title="Log In"></a>
|
||||||
|
<a href="/register"><img src="/assets/images/button_register.gif" alt="Register" title="Register"></a>
|
||||||
|
<a href="/help"><img src="/assets/images/button_help.gif" alt="Help" title="Help"></a>`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PageData holds common page template data
|
||||||
|
type PageData struct {
|
||||||
|
Title string
|
||||||
|
Content string
|
||||||
|
TopNav string
|
||||||
|
LeftSide string
|
||||||
|
RightSide string
|
||||||
|
TotalTime string
|
||||||
|
NumQueries string
|
||||||
|
Version string
|
||||||
|
Build string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenderPage renders a page using the layout template with common data and additional custom data
|
||||||
|
func RenderPage(ctx router.Ctx, pageData PageData, additionalData map[string]any) error {
|
||||||
|
if template.Cache == nil || auth.Manager == nil {
|
||||||
|
return fmt.Errorf("singleton template.Cache or auth.Manager not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
layoutTmpl, err := template.Cache.Load("layout.html")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load layout template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the base template data with common fields
|
||||||
|
data := map[string]any{
|
||||||
|
"title": pageData.Title,
|
||||||
|
"content": pageData.Content,
|
||||||
|
"topnav": GenerateTopNav(ctx),
|
||||||
|
"leftside": pageData.LeftSide,
|
||||||
|
"rightside": pageData.RightSide,
|
||||||
|
"totaltime": middleware.GetRequestTime(ctx),
|
||||||
|
"numqueries": pageData.NumQueries,
|
||||||
|
"version": pageData.Version,
|
||||||
|
"build": pageData.Build,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge in additional data (overwrites common data if keys conflict)
|
||||||
|
maps.Copy(data, additionalData)
|
||||||
|
|
||||||
|
// Set defaults for empty fields
|
||||||
|
if data["leftside"] == "" {
|
||||||
|
data["leftside"] = ""
|
||||||
|
}
|
||||||
|
if data["rightside"] == "" {
|
||||||
|
data["rightside"] = ""
|
||||||
|
}
|
||||||
|
if data["numqueries"] == "" {
|
||||||
|
data["numqueries"] = "0"
|
||||||
|
}
|
||||||
|
if data["version"] == "" {
|
||||||
|
data["version"] = "1.0.0"
|
||||||
|
}
|
||||||
|
if data["build"] == "" {
|
||||||
|
data["build"] = "dev"
|
||||||
|
}
|
||||||
|
|
||||||
|
layoutTmpl.WriteTo(ctx, data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPageData creates a new PageData with sensible defaults
|
||||||
|
func NewPageData(title, content string) PageData {
|
||||||
|
return PageData{
|
||||||
|
Title: title,
|
||||||
|
Content: content,
|
||||||
|
LeftSide: "",
|
||||||
|
RightSide: "",
|
||||||
|
NumQueries: "0",
|
||||||
|
Version: "1.0.0",
|
||||||
|
Build: "dev",
|
||||||
|
}
|
||||||
|
}
|
@ -12,7 +12,10 @@ import (
|
|||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Cache struct {
|
// Cache is the global singleton instance
|
||||||
|
var Cache *TemplateCache
|
||||||
|
|
||||||
|
type TemplateCache struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
templates map[string]*Template
|
templates map[string]*Template
|
||||||
basePath string
|
basePath string
|
||||||
@ -28,10 +31,10 @@ type Template struct {
|
|||||||
content string
|
content string
|
||||||
modTime time.Time
|
modTime time.Time
|
||||||
filePath string
|
filePath string
|
||||||
cache *Cache
|
cache *TemplateCache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCache(basePath string) *Cache {
|
func NewCache(basePath string) *TemplateCache {
|
||||||
if basePath == "" {
|
if basePath == "" {
|
||||||
exe, err := os.Executable()
|
exe, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -41,13 +44,18 @@ func NewCache(basePath string) *Cache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Cache{
|
return &TemplateCache{
|
||||||
templates: make(map[string]*Template),
|
templates: make(map[string]*Template),
|
||||||
basePath: basePath,
|
basePath: basePath,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) Load(name string) (*Template, error) {
|
// InitializeCache initializes the global Cache singleton
|
||||||
|
func InitializeCache(basePath string) {
|
||||||
|
Cache = NewCache(basePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TemplateCache) Load(name string) (*Template, error) {
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
tmpl, exists := c.templates[name]
|
tmpl, exists := c.templates[name]
|
||||||
c.mu.RUnlock()
|
c.mu.RUnlock()
|
||||||
@ -62,7 +70,7 @@ func (c *Cache) Load(name string) (*Template, error) {
|
|||||||
return c.loadFromFile(name)
|
return c.loadFromFile(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) loadFromFile(name string) (*Template, error) {
|
func (c *TemplateCache) loadFromFile(name string) (*Template, error) {
|
||||||
filePath := filepath.Join(c.basePath, "templates", name)
|
filePath := filepath.Join(c.basePath, "templates", name)
|
||||||
|
|
||||||
info, err := os.Stat(filePath)
|
info, err := os.Stat(filePath)
|
||||||
@ -90,7 +98,7 @@ func (c *Cache) loadFromFile(name string) (*Template, error) {
|
|||||||
return tmpl, nil
|
return tmpl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) checkAndReload(tmpl *Template) error {
|
func (c *TemplateCache) checkAndReload(tmpl *Template) error {
|
||||||
info, err := os.Stat(tmpl.filePath)
|
info, err := os.Stat(tmpl.filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -1,30 +1,21 @@
|
|||||||
{flashhtml}
|
{error_message}
|
||||||
|
|
||||||
<form action="/login" method="post">
|
<form action="/login" method="post">
|
||||||
{csrf}
|
{csrf_field}
|
||||||
<table width="75%">
|
<table width="75%">
|
||||||
<tr>
|
<tr>
|
||||||
<td width="30%">Username:</td>
|
<td width="30%">Email/Username:</td>
|
||||||
<td><input type="text" size="30" name="username"></td>
|
<td><input type="text" size="30" name="email" required></td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td>Password:</td>
|
<td>Password:</td>
|
||||||
<td><input type="password" size="30" name="password"></td>
|
<td><input type="password" size="30" name="password" required></td>
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>Remember me?</td>
|
|
||||||
<td><input type="checkbox" name="rememberme" value="yes"> Yes</td>
|
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td colspan="2"><input type="submit" name="submit" value="Log In"></td>
|
<td colspan="2"><input type="submit" name="submit" value="Log In"></td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td colspan="2">
|
<td colspan="2">
|
||||||
Checking the "Remember Me" option will store your login information in a cookie
|
|
||||||
so you don't have to enter it next time you get online.
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
Want to play? You gotta <a href="/register">register your own character.</a>
|
Want to play? You gotta <a href="/register">register your own character.</a>
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
{flashhtml}
|
{error_message}
|
||||||
|
|
||||||
<form action="/register" method="post">
|
<form action="/register" method="post">
|
||||||
{csrf}
|
{csrf_field}
|
||||||
<table width="80%">
|
<table width="80%">
|
||||||
<tr>
|
<tr>
|
||||||
<td width="20%">Username:</td>
|
<td width="20%">Username:</td>
|
||||||
<td>
|
<td>
|
||||||
<input type="text" name="username" size="30" maxlength="30">
|
<input type="text" name="username" size="30" maxlength="30" value="{username}" required>
|
||||||
<br>
|
<br>
|
||||||
Usernames must be 30 alphanumeric characters or less.
|
Usernames must be 30 alphanumeric characters or less.
|
||||||
<br><br><br>
|
<br><br><br>
|
||||||
@ -14,45 +14,29 @@
|
|||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td>Password:</td>
|
<td>Password:</td>
|
||||||
<td><input type="password" name="password1" size="30" maxlength="10"></td>
|
<td><input type="password" name="password" size="30" required></td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td>Verify Password:</td>
|
<td>Verify Password:</td>
|
||||||
<td>
|
<td>
|
||||||
<input type="password" name="password2" size="30" maxlength="10">
|
<input type="password" name="confirm_password" size="30" required>
|
||||||
<br>
|
<br>
|
||||||
Passwords must be 10 alphanumeric characters or less.
|
Passwords must be at least 6 characters.
|
||||||
<br><br><br>
|
<br><br><br>
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td>Email Address:</td>
|
<td>Email Address:</td>
|
||||||
<td><input type="email" name="email1" size="30" maxlength="100"></td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>Verify Email:</td>
|
|
||||||
<td>
|
<td>
|
||||||
<input type="text" name="email2" size="30" maxlength="100">
|
<input type="email" name="email" size="30" maxlength="100" value="{email}" required>
|
||||||
{verifytext}
|
<br>
|
||||||
|
A valid email address is required.
|
||||||
<br><br><br>
|
<br><br><br>
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
|
||||||
<td>Character Class:</td>
|
|
||||||
<td>
|
|
||||||
<select name="charclass">
|
|
||||||
<option value="1">{class1name}</option>
|
|
||||||
<option value="2">{class2name}</option>
|
|
||||||
<option value="3">{class3name}</option>
|
|
||||||
</select>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td colspan="2">See <a href="/help">Help</a> for more information about character classes.<br><br></td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
<tr>
|
||||||
<td colspan="2">
|
<td colspan="2">
|
||||||
<input type="submit" name="submit" value="Submit">
|
<input type="submit" name="submit" value="Register">
|
||||||
<input type="reset" name="reset" value="Reset">
|
<input type="reset" name="reset" value="Reset">
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user