first commit
This commit is contained in:
parent
baa56b79bd
commit
5bcaa4c89f
105
auth/auth.go
Normal file
105
auth/auth.go
Normal file
@ -0,0 +1,105 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
sushi "git.sharkk.net/Sharkk/Sushi"
|
||||
"git.sharkk.net/Sharkk/Sushi/session"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const UserCtxKey = "user"
|
||||
|
||||
// Middleware adds authentication handling
|
||||
func Middleware(userLookup func(int) any) sushi.Middleware {
|
||||
return func(next sushi.Handler) sushi.Handler {
|
||||
return func(ctx sushi.Ctx, params []string) {
|
||||
sess := session.GetCurrentSession(ctx)
|
||||
if sess != nil && sess.UserID > 0 && userLookup != nil {
|
||||
user := userLookup(sess.UserID)
|
||||
if user != nil {
|
||||
ctx.SetUserValue(UserCtxKey, user)
|
||||
} else {
|
||||
sess.SetUserID(0)
|
||||
session.StoreSession(sess)
|
||||
}
|
||||
}
|
||||
next(ctx, params)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth middleware that redirects unauthenticated users
|
||||
func RequireAuth(redirectPath ...string) sushi.Middleware {
|
||||
redirect := "/login"
|
||||
if len(redirectPath) > 0 && redirectPath[0] != "" {
|
||||
redirect = redirectPath[0]
|
||||
}
|
||||
|
||||
return func(next sushi.Handler) sushi.Handler {
|
||||
return func(ctx sushi.Ctx, params []string) {
|
||||
if !IsAuthenticated(ctx) {
|
||||
ctx.Redirect(redirect, fasthttp.StatusFound)
|
||||
return
|
||||
}
|
||||
next(ctx, params)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireGuest middleware that redirects authenticated users
|
||||
func RequireGuest(redirectPath ...string) sushi.Middleware {
|
||||
redirect := "/"
|
||||
if len(redirectPath) > 0 && redirectPath[0] != "" {
|
||||
redirect = redirectPath[0]
|
||||
}
|
||||
|
||||
return func(next sushi.Handler) sushi.Handler {
|
||||
return func(ctx sushi.Ctx, params []string) {
|
||||
if IsAuthenticated(ctx) {
|
||||
ctx.Redirect(redirect, fasthttp.StatusFound)
|
||||
return
|
||||
}
|
||||
next(ctx, params)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsAuthenticated checks if the current request is from an authenticated user
|
||||
func IsAuthenticated(ctx sushi.Ctx) bool {
|
||||
user := ctx.UserValue(UserCtxKey)
|
||||
return user != nil
|
||||
}
|
||||
|
||||
// GetCurrentUser returns the current authenticated user
|
||||
func GetCurrentUser(ctx sushi.Ctx) any {
|
||||
return ctx.UserValue(UserCtxKey)
|
||||
}
|
||||
|
||||
// Login authenticates a user session
|
||||
func Login(ctx sushi.Ctx, userID int, user any) {
|
||||
sess := session.GetCurrentSession(ctx)
|
||||
if sess != nil {
|
||||
sess.SetUserID(userID)
|
||||
sess.RegenerateID()
|
||||
session.StoreSession(sess)
|
||||
|
||||
ctx.SetUserValue(session.SessionCtxKey, sess)
|
||||
ctx.SetUserValue(UserCtxKey, user)
|
||||
|
||||
session.SetSessionCookie(ctx, sess.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Logout clears the user session
|
||||
func Logout(ctx sushi.Ctx) {
|
||||
sess := session.GetCurrentSession(ctx)
|
||||
if sess != nil {
|
||||
sess.SetUserID(0)
|
||||
sess.RegenerateID()
|
||||
session.StoreSession(sess)
|
||||
|
||||
ctx.SetUserValue(session.SessionCtxKey, sess)
|
||||
session.SetSessionCookie(ctx, sess.ID)
|
||||
}
|
||||
|
||||
ctx.SetUserValue(UserCtxKey, nil)
|
||||
}
|
77
cookies.go
Normal file
77
cookies.go
Normal file
@ -0,0 +1,77 @@
|
||||
package sushi
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type CookieOptions struct {
|
||||
Name string
|
||||
Value string
|
||||
Path string
|
||||
Domain string
|
||||
Expires time.Time
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HTTPOnly bool
|
||||
SameSite string
|
||||
}
|
||||
|
||||
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
|
||||
cookie := &fasthttp.Cookie{}
|
||||
|
||||
cookie.SetKey(opts.Name)
|
||||
cookie.SetValue(opts.Value)
|
||||
|
||||
if opts.Path != "" {
|
||||
cookie.SetPath(opts.Path)
|
||||
} else {
|
||||
cookie.SetPath("/")
|
||||
}
|
||||
|
||||
if opts.Domain != "" {
|
||||
cookie.SetDomain(opts.Domain)
|
||||
}
|
||||
|
||||
if !opts.Expires.IsZero() {
|
||||
cookie.SetExpire(opts.Expires)
|
||||
}
|
||||
|
||||
if opts.MaxAge > 0 {
|
||||
cookie.SetMaxAge(opts.MaxAge)
|
||||
}
|
||||
|
||||
cookie.SetSecure(opts.Secure)
|
||||
cookie.SetHTTPOnly(opts.HTTPOnly)
|
||||
|
||||
switch opts.SameSite {
|
||||
case "strict":
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
|
||||
case "lax":
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
case "none":
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
|
||||
default:
|
||||
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
|
||||
}
|
||||
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
|
||||
return string(ctx.Request.Header.Cookie(name))
|
||||
}
|
||||
|
||||
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
|
||||
SetSecureCookie(ctx, CookieOptions{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
HTTPOnly: true,
|
||||
Secure: true,
|
||||
SameSite: "lax",
|
||||
})
|
||||
}
|
138
csrf/csrf.go
Normal file
138
csrf/csrf.go
Normal file
@ -0,0 +1,138 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
sushi "git.sharkk.net/Sharkk/Sushi"
|
||||
"git.sharkk.net/Sharkk/Sushi/session"
|
||||
)
|
||||
|
||||
const (
|
||||
CSRFTokenLength = 32
|
||||
CSRFTokenFieldName = "_csrf_token"
|
||||
CSRFSessionKey = "csrf_token"
|
||||
SessionCtxKey = "session"
|
||||
)
|
||||
|
||||
// GetCurrentSession retrieves the session from context
|
||||
func GetCurrentSession(ctx sushi.Ctx) *session.Session {
|
||||
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
|
||||
return sess
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCSRFToken creates a new CSRF token and stores it in the session
|
||||
func GenerateCSRFToken(ctx sushi.Ctx) string {
|
||||
tokenBytes := make([]byte, CSRFTokenLength)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
token := base64.URLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
if sess := GetCurrentSession(ctx); sess != nil {
|
||||
sess.Set(CSRFSessionKey, token)
|
||||
session.StoreSession(sess)
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// GetCSRFToken retrieves the current CSRF token from session, generating one if needed
|
||||
func GetCSRFToken(ctx sushi.Ctx) string {
|
||||
sess := GetCurrentSession(ctx)
|
||||
if sess == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if existingToken, ok := sess.Get(CSRFSessionKey); ok {
|
||||
if tokenStr, ok := existingToken.(string); ok {
|
||||
return tokenStr
|
||||
}
|
||||
}
|
||||
|
||||
return GenerateCSRFToken(ctx)
|
||||
}
|
||||
|
||||
// ValidateCSRFToken verifies a CSRF token against the stored session token
|
||||
func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool {
|
||||
if submittedToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
sess := GetCurrentSession(ctx)
|
||||
if sess == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
storedToken, ok := sess.Get(CSRFSessionKey)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
storedTokenStr, ok := storedToken.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1
|
||||
}
|
||||
|
||||
// CSRFHiddenField generates an HTML hidden input field with the CSRF token
|
||||
func CSRFHiddenField(ctx sushi.Ctx) string {
|
||||
token := GetCSRFToken(ctx)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
|
||||
CSRFTokenFieldName, token)
|
||||
}
|
||||
|
||||
// CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token
|
||||
func CSRFTokenMeta(ctx sushi.Ctx) string {
|
||||
token := GetCSRFToken(ctx)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, token)
|
||||
}
|
||||
|
||||
// ValidateFormCSRFToken validates CSRF token from form data
|
||||
func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
|
||||
tokenBytes := ctx.PostArgs().Peek(CSRFTokenFieldName)
|
||||
if len(tokenBytes) == 0 {
|
||||
tokenBytes = ctx.QueryArgs().Peek(CSRFTokenFieldName)
|
||||
}
|
||||
|
||||
if len(tokenBytes) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return ValidateCSRFToken(ctx, string(tokenBytes))
|
||||
}
|
||||
|
||||
// Middleware returns middleware that automatically validates CSRF tokens
|
||||
func Middleware() sushi.Middleware {
|
||||
return func(next sushi.Handler) sushi.Handler {
|
||||
return func(ctx sushi.Ctx, params []string) {
|
||||
method := string(ctx.Method())
|
||||
|
||||
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
|
||||
if !ValidateFormCSRFToken(ctx) {
|
||||
GenerateCSRFToken(ctx)
|
||||
currentPath := string(ctx.Path())
|
||||
ctx.Redirect(currentPath, 302)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next(ctx, params)
|
||||
}
|
||||
}
|
||||
}
|
126
fs.go
Normal file
126
fs.go
Normal file
@ -0,0 +1,126 @@
|
||||
package sushi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// StaticOptions configures static file serving
|
||||
type StaticOptions struct {
|
||||
Root string
|
||||
IndexNames []string
|
||||
GenerateIndexPages bool
|
||||
Compress bool
|
||||
CompressBrotli bool
|
||||
AcceptByteRange bool
|
||||
CacheDuration time.Duration
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
// StaticFS creates a handler for serving static files with fasthttp.FS
|
||||
func StaticFS(fsOptions StaticOptions) Handler {
|
||||
// Set defaults
|
||||
if fsOptions.Root == "" {
|
||||
fsOptions.Root = "."
|
||||
}
|
||||
if len(fsOptions.IndexNames) == 0 {
|
||||
fsOptions.IndexNames = []string{"index.html"}
|
||||
}
|
||||
|
||||
fs := &fasthttp.FS{
|
||||
Root: fsOptions.Root,
|
||||
IndexNames: fsOptions.IndexNames,
|
||||
GenerateIndexPages: fsOptions.GenerateIndexPages,
|
||||
Compress: fsOptions.Compress,
|
||||
CompressBrotli: fsOptions.CompressBrotli,
|
||||
AcceptByteRange: fsOptions.AcceptByteRange,
|
||||
CacheDuration: fsOptions.CacheDuration,
|
||||
}
|
||||
|
||||
if fsOptions.MaxAge > 0 {
|
||||
fs.PathRewrite = func(ctx *fasthttp.RequestCtx) []byte {
|
||||
ctx.Response.Header.Set("Cache-Control", fmt.Sprintf("max-age=%d", fsOptions.MaxAge))
|
||||
return ctx.Path()
|
||||
}
|
||||
}
|
||||
|
||||
fsHandler := fs.NewRequestHandler()
|
||||
|
||||
return func(ctx Ctx, params []string) {
|
||||
fsHandler(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Static creates a simple static file handler
|
||||
func Static(root string) Handler {
|
||||
return StaticFS(StaticOptions{Root: root})
|
||||
}
|
||||
|
||||
// StaticFile serves a single file
|
||||
func StaticFile(filePath string) Handler {
|
||||
return func(ctx Ctx, params []string) {
|
||||
fasthttp.ServeFile(ctx, filePath)
|
||||
}
|
||||
}
|
||||
|
||||
// StaticEmbed creates a handler for embedded files
|
||||
func StaticEmbed(files map[string][]byte) Handler {
|
||||
return func(ctx Ctx, params []string) {
|
||||
requestPath := string(ctx.Path())
|
||||
|
||||
// Try to find the file
|
||||
if data, exists := files[requestPath]; exists {
|
||||
// Set content type based on extension
|
||||
ext := path.Ext(requestPath)
|
||||
contentType := getContentType(ext)
|
||||
ctx.SetContentType(contentType)
|
||||
ctx.Write(data)
|
||||
return
|
||||
}
|
||||
|
||||
// Try index files
|
||||
if requestPath == "/" || strings.HasSuffix(requestPath, "/") {
|
||||
indexPath := requestPath + "index.html"
|
||||
if data, exists := files[indexPath]; exists {
|
||||
ctx.SetContentType("text/html")
|
||||
ctx.Write(data)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func getContentType(ext string) string {
|
||||
switch ext {
|
||||
case ".html", ".htm":
|
||||
return "text/html"
|
||||
case ".css":
|
||||
return "text/css"
|
||||
case ".js":
|
||||
return "application/javascript"
|
||||
case ".json":
|
||||
return "application/json"
|
||||
case ".png":
|
||||
return "image/png"
|
||||
case ".jpg", ".jpeg":
|
||||
return "image/jpeg"
|
||||
case ".gif":
|
||||
return "image/gif"
|
||||
case ".svg":
|
||||
return "image/svg+xml"
|
||||
case ".ico":
|
||||
return "image/x-icon"
|
||||
case ".pdf":
|
||||
return "application/pdf"
|
||||
case ".txt":
|
||||
return "text/plain"
|
||||
default:
|
||||
return "application/octet-stream"
|
||||
}
|
||||
}
|
12
go.mod
12
go.mod
@ -1,3 +1,15 @@
|
||||
module git.sharkk.net/Sharkk/Sushi
|
||||
|
||||
go 1.24.1
|
||||
|
||||
require (
|
||||
github.com/valyala/fasthttp v1.65.0
|
||||
golang.org/x/crypto v0.41.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
)
|
||||
|
14
go.sum
Normal file
14
go.sum
Normal file
@ -0,0 +1,14 @@
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8=
|
||||
github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
79
password/password.go
Normal file
79
password/password.go
Normal file
@ -0,0 +1,79 @@
|
||||
package password
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
const (
|
||||
argonTime = 1
|
||||
argonMemory = 64 * 1024
|
||||
argonThreads = 4
|
||||
argonKeyLen = 32
|
||||
)
|
||||
|
||||
// HashPassword creates an argon2id hash of the password
|
||||
func HashPassword(password string) string {
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
|
||||
hash := argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
|
||||
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, argonMemory, argonTime, argonThreads, b64Salt, b64Hash)
|
||||
|
||||
return encoded
|
||||
}
|
||||
|
||||
// VerifyPassword checks if a password matches the hash
|
||||
func VerifyPassword(password, encodedHash string) (bool, error) {
|
||||
parts := strings.Split(encodedHash, "$")
|
||||
if len(parts) != 6 {
|
||||
return false, fmt.Errorf("invalid hash format")
|
||||
}
|
||||
|
||||
if parts[1] != "argon2id" {
|
||||
return false, fmt.Errorf("invalid hash variant")
|
||||
}
|
||||
|
||||
var version int
|
||||
_, err := fmt.Sscanf(parts[2], "v=%d", &version)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return false, fmt.Errorf("incompatible argon2 version")
|
||||
}
|
||||
|
||||
var m, t, p uint32
|
||||
_, err = fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &m, &t, &p)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
hash := argon2.IDKey([]byte(password), salt, t, m, uint8(p), uint32(len(expectedHash)))
|
||||
|
||||
if subtle.ConstantTimeCompare(hash, expectedHash) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
275
router.go
Normal file
275
router.go
Normal file
@ -0,0 +1,275 @@
|
||||
package sushi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type node struct {
|
||||
segment string
|
||||
handler Handler
|
||||
children []*node
|
||||
isDynamic bool
|
||||
isWildcard bool
|
||||
maxParams uint8
|
||||
}
|
||||
|
||||
type Router struct {
|
||||
get *node
|
||||
post *node
|
||||
put *node
|
||||
patch *node
|
||||
delete *node
|
||||
middleware []Middleware
|
||||
paramsBuffer []string
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
router *Router
|
||||
prefix string
|
||||
middleware []Middleware
|
||||
}
|
||||
|
||||
// New creates a new Router instance
|
||||
func NewRouter() *Router {
|
||||
return &Router{
|
||||
get: &node{},
|
||||
post: &node{},
|
||||
put: &node{},
|
||||
patch: &node{},
|
||||
delete: &node{},
|
||||
middleware: []Middleware{},
|
||||
paramsBuffer: make([]string, 64),
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements the Handler interface for fasthttp
|
||||
func (r *Router) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
path := string(ctx.Path())
|
||||
method := string(ctx.Method())
|
||||
|
||||
h, params, found := r.Lookup(method, path)
|
||||
if !found {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
h(ctx, params)
|
||||
}
|
||||
|
||||
// Handler returns a fasthttp request handler
|
||||
func (r *Router) Handler() fasthttp.RequestHandler {
|
||||
return r.ServeHTTP
|
||||
}
|
||||
|
||||
// Use adds middleware to the router
|
||||
func (r *Router) Use(mw ...Middleware) *Router {
|
||||
r.middleware = append(r.middleware, mw...)
|
||||
return r
|
||||
}
|
||||
|
||||
// Group creates a new route group
|
||||
func (r *Router) Group(prefix string) *Group {
|
||||
return &Group{router: r, prefix: prefix, middleware: []Middleware{}}
|
||||
}
|
||||
|
||||
// Use adds middleware to the group
|
||||
func (g *Group) Use(mw ...Middleware) *Group {
|
||||
g.middleware = append(g.middleware, mw...)
|
||||
return g
|
||||
}
|
||||
|
||||
// Group creates a nested group
|
||||
func (g *Group) Group(prefix string) *Group {
|
||||
return &Group{
|
||||
router: g.router,
|
||||
prefix: g.prefix + prefix,
|
||||
middleware: append([]Middleware{}, g.middleware...),
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP method handlers for Router
|
||||
func (r *Router) Get(path string, h Handler) error { return r.Handle("GET", path, h) }
|
||||
func (r *Router) Post(path string, h Handler) error { return r.Handle("POST", path, h) }
|
||||
func (r *Router) Put(path string, h Handler) error { return r.Handle("PUT", path, h) }
|
||||
func (r *Router) Patch(path string, h Handler) error { return r.Handle("PATCH", path, h) }
|
||||
func (r *Router) Delete(path string, h Handler) error { return r.Handle("DELETE", path, h) }
|
||||
|
||||
// HTTP method handlers for Group
|
||||
func (g *Group) Get(path string, h Handler) error { return g.Handle("GET", path, h) }
|
||||
func (g *Group) Post(path string, h Handler) error { return g.Handle("POST", path, h) }
|
||||
func (g *Group) Put(path string, h Handler) error { return g.Handle("PUT", path, h) }
|
||||
func (g *Group) Patch(path string, h Handler) error { return g.Handle("PATCH", path, h) }
|
||||
func (g *Group) Delete(path string, h Handler) error { return g.Handle("DELETE", path, h) }
|
||||
|
||||
// Handle registers a handler for the given method and path
|
||||
func (r *Router) Handle(method, path string, h Handler) error {
|
||||
root := r.methodNode(method)
|
||||
if root == nil {
|
||||
return fmt.Errorf("unsupported method: %s", method)
|
||||
}
|
||||
return r.addRoute(root, path, h, r.middleware)
|
||||
}
|
||||
|
||||
// Handle registers a handler in the group
|
||||
func (g *Group) Handle(method, path string, h Handler) error {
|
||||
root := g.router.methodNode(method)
|
||||
if root == nil {
|
||||
return fmt.Errorf("unsupported method: %s", method)
|
||||
}
|
||||
mw := append([]Middleware{}, g.router.middleware...)
|
||||
mw = append(mw, g.middleware...)
|
||||
return g.router.addRoute(root, g.prefix+path, h, mw)
|
||||
}
|
||||
|
||||
func (r *Router) methodNode(method string) *node {
|
||||
switch method {
|
||||
case "GET":
|
||||
return r.get
|
||||
case "POST":
|
||||
return r.post
|
||||
case "PUT":
|
||||
return r.put
|
||||
case "PATCH":
|
||||
return r.patch
|
||||
case "DELETE":
|
||||
return r.delete
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func applyMiddleware(h Handler, mw []Middleware) Handler {
|
||||
for i := len(mw) - 1; i >= 0; i-- {
|
||||
h = mw[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
|
||||
if start >= len(path) {
|
||||
return "", start, false
|
||||
}
|
||||
if path[start] == '/' {
|
||||
start++
|
||||
}
|
||||
if start >= len(path) {
|
||||
return "", start, false
|
||||
}
|
||||
end = start
|
||||
for end < len(path) && path[end] != '/' {
|
||||
end++
|
||||
}
|
||||
return path[start:end], end, end < len(path)
|
||||
}
|
||||
|
||||
func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error {
|
||||
h = applyMiddleware(h, mw)
|
||||
if path == "/" {
|
||||
root.handler = h
|
||||
return nil
|
||||
}
|
||||
current := root
|
||||
pos := 0
|
||||
lastWC := false
|
||||
count := uint8(0)
|
||||
for {
|
||||
seg, newPos, more := readSegment(path, pos)
|
||||
if seg == "" {
|
||||
break
|
||||
}
|
||||
isDyn := len(seg) > 1 && seg[0] == ':'
|
||||
isWC := len(seg) > 0 && seg[0] == '*'
|
||||
if isWC {
|
||||
if lastWC || more {
|
||||
return fmt.Errorf("wildcard must be the last segment in the path")
|
||||
}
|
||||
lastWC = true
|
||||
}
|
||||
if isDyn || isWC {
|
||||
count++
|
||||
}
|
||||
var child *node
|
||||
for _, c := range current.children {
|
||||
if c.segment == seg {
|
||||
child = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if child == nil {
|
||||
child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC}
|
||||
current.children = append(current.children, child)
|
||||
}
|
||||
if child.maxParams < count {
|
||||
child.maxParams = count
|
||||
}
|
||||
current = child
|
||||
pos = newPos
|
||||
}
|
||||
current.handler = h
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lookup finds a handler matching method and path
|
||||
func (r *Router) Lookup(method, path string) (Handler, []string, bool) {
|
||||
root := r.methodNode(method)
|
||||
if root == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
if path == "/" {
|
||||
return root.handler, nil, root.handler != nil
|
||||
}
|
||||
|
||||
buffer := r.paramsBuffer
|
||||
if cap(buffer) < int(root.maxParams) {
|
||||
buffer = make([]string, root.maxParams)
|
||||
r.paramsBuffer = buffer
|
||||
}
|
||||
buffer = buffer[:0]
|
||||
|
||||
h, paramCount, found := match(root, path, 0, &buffer)
|
||||
if !found {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
return h, buffer[:paramCount], true
|
||||
}
|
||||
|
||||
func match(current *node, path string, start int, params *[]string) (Handler, int, bool) {
|
||||
paramCount := 0
|
||||
|
||||
for _, c := range current.children {
|
||||
if c.isWildcard {
|
||||
rem := path[start:]
|
||||
if len(rem) > 0 && rem[0] == '/' {
|
||||
rem = rem[1:]
|
||||
}
|
||||
*params = append(*params, rem)
|
||||
return c.handler, 1, c.handler != nil
|
||||
}
|
||||
}
|
||||
|
||||
seg, pos, more := readSegment(path, start)
|
||||
if seg == "" {
|
||||
return current.handler, 0, current.handler != nil
|
||||
}
|
||||
|
||||
for _, c := range current.children {
|
||||
if c.segment == seg || c.isDynamic {
|
||||
if c.isDynamic {
|
||||
*params = append(*params, seg)
|
||||
paramCount++
|
||||
}
|
||||
if !more {
|
||||
return c.handler, paramCount, c.handler != nil
|
||||
}
|
||||
h, nestedCount, ok := match(c, path, pos, params)
|
||||
if ok {
|
||||
return h, paramCount + nestedCount, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, false
|
||||
}
|
33
server.go
Normal file
33
server.go
Normal file
@ -0,0 +1,33 @@
|
||||
package sushi
|
||||
|
||||
// Listen starts the server on the specified address
|
||||
func (a *App) Listen(addr ...string) error {
|
||||
address := ":8080"
|
||||
if len(addr) > 0 && addr[0] != "" {
|
||||
address = addr[0]
|
||||
}
|
||||
return a.Server.ListenAndServe(address)
|
||||
}
|
||||
|
||||
// ListenTLS starts the server with TLS on the specified address
|
||||
func (a *App) ListenTLS(addr, certFile, keyFile string) error {
|
||||
return a.Server.ListenAndServeTLS(addr, certFile, keyFile)
|
||||
}
|
||||
|
||||
// Use adds middleware to the app's router
|
||||
func (a *App) Use(mw ...Middleware) *App {
|
||||
a.Router.Use(mw...)
|
||||
return a
|
||||
}
|
||||
|
||||
// Group creates a new route group on the app's router
|
||||
func (a *App) Group(prefix string) *Group {
|
||||
return a.Router.Group(prefix)
|
||||
}
|
||||
|
||||
// HTTP method handlers for App
|
||||
func (a *App) Get(path string, h Handler) error { return a.Router.Get(path, h) }
|
||||
func (a *App) Post(path string, h Handler) error { return a.Router.Post(path, h) }
|
||||
func (a *App) Put(path string, h Handler) error { return a.Router.Put(path, h) }
|
||||
func (a *App) Patch(path string, h Handler) error { return a.Router.Patch(path, h) }
|
||||
func (a *App) Delete(path string, h Handler) error { return a.Router.Delete(path, h) }
|
30
session/middleware.go
Normal file
30
session/middleware.go
Normal file
@ -0,0 +1,30 @@
|
||||
package session
|
||||
|
||||
import sushi "git.sharkk.net/Sharkk/Sushi"
|
||||
|
||||
// Middleware provides session handling
|
||||
func Middleware() sushi.Middleware {
|
||||
return func(next sushi.Handler) sushi.Handler {
|
||||
return func(ctx sushi.Ctx, params []string) {
|
||||
sessionID := sushi.GetCookie(ctx, SessionCookieName)
|
||||
var sess *Session
|
||||
|
||||
if sessionID != "" {
|
||||
if existingSess, exists := GetSession(sessionID); exists {
|
||||
sess = existingSess
|
||||
sess.Touch()
|
||||
StoreSession(sess)
|
||||
SetSessionCookie(ctx, sessionID)
|
||||
}
|
||||
}
|
||||
|
||||
if sess == nil {
|
||||
sess = CreateSession(0) // Guest session
|
||||
SetSessionCookie(ctx, sess.ID)
|
||||
}
|
||||
|
||||
ctx.SetUserValue(SessionCtxKey, sess)
|
||||
next(ctx, params)
|
||||
}
|
||||
}
|
||||
}
|
280
session/session.go
Normal file
280
session/session.go
Normal file
@ -0,0 +1,280 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sushi "git.sharkk.net/Sharkk/Sushi"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultExpiration = 24 * time.Hour
|
||||
IDLength = 32
|
||||
SessionCookieName = "session_id"
|
||||
SessionCtxKey = "session"
|
||||
)
|
||||
|
||||
// Session represents a user session
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
UserID int `json:"user_id"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
// SessionManager handles session storage and persistence
|
||||
type SessionManager struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*Session
|
||||
filePath string
|
||||
}
|
||||
|
||||
type sessionData struct {
|
||||
UserID int `json:"user_id"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
var sessionManager *SessionManager
|
||||
|
||||
// InitSessions initializes the global session manager
|
||||
func InitSessions(filePath string) {
|
||||
if sessionManager != nil {
|
||||
panic("session manager already initialized")
|
||||
}
|
||||
|
||||
sessionManager = &SessionManager{
|
||||
sessions: make(map[string]*Session),
|
||||
filePath: filePath,
|
||||
}
|
||||
|
||||
sessionManager.load()
|
||||
}
|
||||
|
||||
// NewSession creates a new session
|
||||
func NewSession(userID int) *Session {
|
||||
return &Session{
|
||||
ID: generateSessionID(),
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().Add(DefaultExpiration).Unix(),
|
||||
Data: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
func generateSessionID() string {
|
||||
bytes := make([]byte, IDLength)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// Session methods
|
||||
func (s *Session) IsExpired() bool {
|
||||
return time.Now().Unix() > s.ExpiresAt
|
||||
}
|
||||
|
||||
func (s *Session) Touch() {
|
||||
s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix()
|
||||
}
|
||||
|
||||
func (s *Session) Set(key string, value any) {
|
||||
s.Data[key] = value
|
||||
}
|
||||
|
||||
func (s *Session) Get(key string) (any, bool) {
|
||||
value, exists := s.Data[key]
|
||||
return value, exists
|
||||
}
|
||||
|
||||
func (s *Session) Delete(key string) {
|
||||
delete(s.Data, key)
|
||||
}
|
||||
|
||||
func (s *Session) SetFlash(key string, value any) {
|
||||
s.Set("flash_"+key, value)
|
||||
}
|
||||
|
||||
func (s *Session) GetFlash(key string) (any, bool) {
|
||||
flashKey := "flash_" + key
|
||||
value, exists := s.Get(flashKey)
|
||||
if exists {
|
||||
s.Delete(flashKey)
|
||||
}
|
||||
return value, exists
|
||||
}
|
||||
|
||||
func (s *Session) GetFlashMessage(key string) string {
|
||||
if flash, exists := s.GetFlash(key); exists {
|
||||
if msg, ok := flash.(string); ok {
|
||||
return msg
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Session) RegenerateID() {
|
||||
oldID := s.ID
|
||||
s.ID = generateSessionID()
|
||||
|
||||
if sessionManager != nil {
|
||||
sessionManager.mu.Lock()
|
||||
delete(sessionManager.sessions, oldID)
|
||||
sessionManager.sessions[s.ID] = s
|
||||
sessionManager.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) SetUserID(userID int) {
|
||||
s.UserID = userID
|
||||
}
|
||||
|
||||
// GetCurrentSession retrieves the session from context
|
||||
func GetCurrentSession(ctx sushi.Ctx) *Session {
|
||||
if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
|
||||
return sess
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SessionManager methods
|
||||
func (sm *SessionManager) Create(userID int) *Session {
|
||||
sess := NewSession(userID)
|
||||
sm.mu.Lock()
|
||||
sm.sessions[sess.ID] = sess
|
||||
sm.mu.Unlock()
|
||||
return sess
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
|
||||
sm.mu.RLock()
|
||||
sess, exists := sm.sessions[sessionID]
|
||||
sm.mu.RUnlock()
|
||||
|
||||
if !exists || sess.IsExpired() {
|
||||
if exists {
|
||||
sm.Delete(sessionID)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return sess, true
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Store(sess *Session) {
|
||||
sm.mu.Lock()
|
||||
sm.sessions[sess.ID] = sess
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Delete(sessionID string) {
|
||||
sm.mu.Lock()
|
||||
delete(sm.sessions, sessionID)
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Cleanup() {
|
||||
sm.mu.Lock()
|
||||
for id, sess := range sm.sessions {
|
||||
if sess.IsExpired() {
|
||||
delete(sm.sessions, id)
|
||||
}
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) load() {
|
||||
if sm.filePath == "" {
|
||||
return
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(sm.filePath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var sessionsData map[string]*sessionData
|
||||
if err := json.Unmarshal(data, &sessionsData); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
sm.mu.Lock()
|
||||
for id, data := range sessionsData {
|
||||
if data != nil && data.ExpiresAt > now {
|
||||
sess := &Session{
|
||||
ID: id,
|
||||
UserID: data.UserID,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
Data: data.Data,
|
||||
}
|
||||
if sess.Data == nil {
|
||||
sess.Data = make(map[string]any)
|
||||
}
|
||||
sm.sessions[id] = sess
|
||||
}
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Save() error {
|
||||
if sm.filePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.Cleanup()
|
||||
|
||||
sm.mu.RLock()
|
||||
sessionsData := make(map[string]*sessionData, len(sm.sessions))
|
||||
for id, sess := range sm.sessions {
|
||||
sessionsData[id] = &sessionData{
|
||||
UserID: sess.UserID,
|
||||
ExpiresAt: sess.ExpiresAt,
|
||||
Data: sess.Data,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(sessionsData, "", "\t")
|
||||
sm.mu.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(sm.filePath, data, 0600)
|
||||
}
|
||||
|
||||
// Package-level session functions
|
||||
func CreateSession(userID int) *Session {
|
||||
return sessionManager.Create(userID)
|
||||
}
|
||||
|
||||
func GetSession(sessionID string) (*Session, bool) {
|
||||
return sessionManager.Get(sessionID)
|
||||
}
|
||||
|
||||
func StoreSession(sess *Session) {
|
||||
sessionManager.Store(sess)
|
||||
}
|
||||
|
||||
func CleanupSessions() {
|
||||
sessionManager.Cleanup()
|
||||
}
|
||||
|
||||
func SaveSessions() error {
|
||||
return sessionManager.Save()
|
||||
}
|
||||
|
||||
func SetSessionCookie(ctx sushi.Ctx, sessionID string) {
|
||||
sushi.SetSecureCookie(ctx, sushi.CookieOptions{
|
||||
Name: SessionCookieName,
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
Expires: time.Now().Add(24 * time.Hour),
|
||||
HTTPOnly: true,
|
||||
Secure: sushi.IsHTTPS(ctx),
|
||||
SameSite: "lax",
|
||||
})
|
||||
}
|
89
sushi.go
Normal file
89
sushi.go
Normal file
@ -0,0 +1,89 @@
|
||||
// Package sushi provides a complete FastHTTP-based web framework
|
||||
// with routing, sessions, authentication, CSRF protection, and utilities.
|
||||
package sushi
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func (h Handler) Serve(ctx Ctx, params []string) {
|
||||
h(ctx, params)
|
||||
}
|
||||
|
||||
func IsHTTPS(ctx Ctx) bool {
|
||||
return ctx.IsTLS() ||
|
||||
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
|
||||
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
|
||||
}
|
||||
|
||||
// StandardHandler adapts a standard fasthttp.RequestHandler to the router's Handler
|
||||
func StandardHandler(handler fasthttp.RequestHandler) Handler {
|
||||
return func(ctx Ctx, _ []string) {
|
||||
handler(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// ServerOptions contains configuration for the FastHTTP server
|
||||
type ServerOptions struct {
|
||||
Addr string
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
MaxConnsPerIP int
|
||||
MaxRequestsPerConn int
|
||||
MaxRequestBodySize int
|
||||
ReduceMemoryUsage bool
|
||||
DisableKeepalive bool
|
||||
TCPKeepalive bool
|
||||
TCPKeepalivePeriod time.Duration
|
||||
}
|
||||
|
||||
type App struct {
|
||||
*fasthttp.Server
|
||||
Router *Router
|
||||
}
|
||||
|
||||
// New creates a new App instance with FastHTTP server and router
|
||||
func New(opts ...ServerOptions) *App {
|
||||
var options ServerOptions
|
||||
if len(opts) > 0 {
|
||||
options = opts[0]
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if options.ReadTimeout == 0 {
|
||||
options.ReadTimeout = 10 * time.Second
|
||||
}
|
||||
if options.WriteTimeout == 0 {
|
||||
options.WriteTimeout = 10 * time.Second
|
||||
}
|
||||
if options.IdleTimeout == 0 {
|
||||
options.IdleTimeout = 60 * time.Second
|
||||
}
|
||||
if options.MaxRequestBodySize == 0 {
|
||||
options.MaxRequestBodySize = 4 * 1024 * 1024 // 4MB
|
||||
}
|
||||
|
||||
router := NewRouter()
|
||||
|
||||
app := &App{
|
||||
Server: &fasthttp.Server{
|
||||
Handler: router.Handler(),
|
||||
ReadTimeout: options.ReadTimeout,
|
||||
WriteTimeout: options.WriteTimeout,
|
||||
IdleTimeout: options.IdleTimeout,
|
||||
MaxConnsPerIP: options.MaxConnsPerIP,
|
||||
MaxRequestsPerConn: options.MaxRequestsPerConn,
|
||||
MaxRequestBodySize: options.MaxRequestBodySize,
|
||||
ReduceMemoryUsage: options.ReduceMemoryUsage,
|
||||
DisableKeepalive: options.DisableKeepalive,
|
||||
TCPKeepalive: options.TCPKeepalive,
|
||||
TCPKeepalivePeriod: options.TCPKeepalivePeriod,
|
||||
},
|
||||
Router: router,
|
||||
}
|
||||
|
||||
return app
|
||||
}
|
48
timing/timing.go
Normal file
48
timing/timing.go
Normal file
@ -0,0 +1,48 @@
|
||||
package timing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
sushi "git.sharkk.net/Sharkk/Sushi"
|
||||
)
|
||||
|
||||
const RequestTimerKey = "request_start_time"
|
||||
|
||||
// Middleware adds request timing functionality
|
||||
func Middleware() sushi.Middleware {
|
||||
return func(next sushi.Handler) sushi.Handler {
|
||||
return func(ctx sushi.Ctx, params []string) {
|
||||
startTime := time.Now()
|
||||
ctx.SetUserValue(RequestTimerKey, startTime)
|
||||
next(ctx, params)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestTime returns the total request processing time in seconds (formatted)
|
||||
func GetRequestTime(ctx sushi.Ctx) string {
|
||||
startTime, ok := ctx.UserValue(RequestTimerKey).(time.Time)
|
||||
if !ok {
|
||||
return "0"
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
seconds := duration.Seconds()
|
||||
|
||||
if seconds < 0.001 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%.3f", seconds)
|
||||
}
|
||||
|
||||
// GetRequestDuration returns the raw duration
|
||||
func GetRequestDuration(ctx sushi.Ctx) time.Duration {
|
||||
startTime, ok := ctx.UserValue(RequestTimerKey).(time.Time)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
return time.Since(startTime)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user