From 5bcaa4c89f883d8ce10c4f382013b817ef42aaa4 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Fri, 15 Aug 2025 14:23:09 -0500 Subject: [PATCH] first commit --- auth/auth.go | 105 ++++++++++++++++ cookies.go | 77 ++++++++++++ csrf/csrf.go | 138 +++++++++++++++++++++ fs.go | 126 +++++++++++++++++++ go.mod | 12 ++ go.sum | 14 +++ password/password.go | 79 ++++++++++++ router.go | 275 +++++++++++++++++++++++++++++++++++++++++ server.go | 33 +++++ session/middleware.go | 30 +++++ session/session.go | 280 ++++++++++++++++++++++++++++++++++++++++++ sushi.go | 89 ++++++++++++++ timing/timing.go | 48 ++++++++ types.go | 7 ++ 14 files changed, 1313 insertions(+) create mode 100644 auth/auth.go create mode 100644 cookies.go create mode 100644 csrf/csrf.go create mode 100644 fs.go create mode 100644 go.sum create mode 100644 password/password.go create mode 100644 router.go create mode 100644 server.go create mode 100644 session/middleware.go create mode 100644 session/session.go create mode 100644 sushi.go create mode 100644 timing/timing.go create mode 100644 types.go diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000..da5af57 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,105 @@ +package auth + +import ( + sushi "git.sharkk.net/Sharkk/Sushi" + "git.sharkk.net/Sharkk/Sushi/session" + "github.com/valyala/fasthttp" +) + +const UserCtxKey = "user" + +// Middleware adds authentication handling +func Middleware(userLookup func(int) any) sushi.Middleware { + return func(next sushi.Handler) sushi.Handler { + return func(ctx sushi.Ctx, params []string) { + sess := session.GetCurrentSession(ctx) + if sess != nil && sess.UserID > 0 && userLookup != nil { + user := userLookup(sess.UserID) + if user != nil { + ctx.SetUserValue(UserCtxKey, user) + } else { + sess.SetUserID(0) + session.StoreSession(sess) + } + } + next(ctx, params) + } + } +} + +// RequireAuth middleware that redirects unauthenticated users +func RequireAuth(redirectPath ...string) sushi.Middleware { + redirect := "/login" + if len(redirectPath) > 0 && redirectPath[0] != "" { + redirect = redirectPath[0] + } + + return func(next sushi.Handler) sushi.Handler { + return func(ctx sushi.Ctx, params []string) { + if !IsAuthenticated(ctx) { + ctx.Redirect(redirect, fasthttp.StatusFound) + return + } + next(ctx, params) + } + } +} + +// RequireGuest middleware that redirects authenticated users +func RequireGuest(redirectPath ...string) sushi.Middleware { + redirect := "/" + if len(redirectPath) > 0 && redirectPath[0] != "" { + redirect = redirectPath[0] + } + + return func(next sushi.Handler) sushi.Handler { + return func(ctx sushi.Ctx, params []string) { + if IsAuthenticated(ctx) { + ctx.Redirect(redirect, fasthttp.StatusFound) + return + } + next(ctx, params) + } + } +} + +// IsAuthenticated checks if the current request is from an authenticated user +func IsAuthenticated(ctx sushi.Ctx) bool { + user := ctx.UserValue(UserCtxKey) + return user != nil +} + +// GetCurrentUser returns the current authenticated user +func GetCurrentUser(ctx sushi.Ctx) any { + return ctx.UserValue(UserCtxKey) +} + +// Login authenticates a user session +func Login(ctx sushi.Ctx, userID int, user any) { + sess := session.GetCurrentSession(ctx) + if sess != nil { + sess.SetUserID(userID) + sess.RegenerateID() + session.StoreSession(sess) + + ctx.SetUserValue(session.SessionCtxKey, sess) + ctx.SetUserValue(UserCtxKey, user) + + session.SetSessionCookie(ctx, sess.ID) + } +} + +// Logout clears the user session +func Logout(ctx sushi.Ctx) { + sess := session.GetCurrentSession(ctx) + if sess != nil { + sess.SetUserID(0) + sess.RegenerateID() + session.StoreSession(sess) + + ctx.SetUserValue(session.SessionCtxKey, sess) + session.SetSessionCookie(ctx, sess.ID) + } + + ctx.SetUserValue(UserCtxKey, nil) +} diff --git a/cookies.go b/cookies.go new file mode 100644 index 0000000..1b23964 --- /dev/null +++ b/cookies.go @@ -0,0 +1,77 @@ +package sushi + +import ( + "time" + + "github.com/valyala/fasthttp" +) + +type CookieOptions struct { + Name string + Value string + Path string + Domain string + Expires time.Time + MaxAge int + Secure bool + HTTPOnly bool + SameSite string +} + +func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) { + cookie := &fasthttp.Cookie{} + + cookie.SetKey(opts.Name) + cookie.SetValue(opts.Value) + + if opts.Path != "" { + cookie.SetPath(opts.Path) + } else { + cookie.SetPath("/") + } + + if opts.Domain != "" { + cookie.SetDomain(opts.Domain) + } + + if !opts.Expires.IsZero() { + cookie.SetExpire(opts.Expires) + } + + if opts.MaxAge > 0 { + cookie.SetMaxAge(opts.MaxAge) + } + + cookie.SetSecure(opts.Secure) + cookie.SetHTTPOnly(opts.HTTPOnly) + + switch opts.SameSite { + case "strict": + cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) + case "lax": + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + case "none": + cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) + default: + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + } + + ctx.Response.Header.SetCookie(cookie) +} + +func GetCookie(ctx *fasthttp.RequestCtx, name string) string { + return string(ctx.Request.Header.Cookie(name)) +} + +func DeleteCookie(ctx *fasthttp.RequestCtx, name string) { + SetSecureCookie(ctx, CookieOptions{ + Name: name, + Value: "", + Path: "/", + Expires: time.Unix(0, 0), + MaxAge: -1, + HTTPOnly: true, + Secure: true, + SameSite: "lax", + }) +} diff --git a/csrf/csrf.go b/csrf/csrf.go new file mode 100644 index 0000000..6367969 --- /dev/null +++ b/csrf/csrf.go @@ -0,0 +1,138 @@ +package csrf + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "fmt" + + sushi "git.sharkk.net/Sharkk/Sushi" + "git.sharkk.net/Sharkk/Sushi/session" +) + +const ( + CSRFTokenLength = 32 + CSRFTokenFieldName = "_csrf_token" + CSRFSessionKey = "csrf_token" + SessionCtxKey = "session" +) + +// GetCurrentSession retrieves the session from context +func GetCurrentSession(ctx sushi.Ctx) *session.Session { + if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok { + return sess + } + return nil +} + +// GenerateCSRFToken creates a new CSRF token and stores it in the session +func GenerateCSRFToken(ctx sushi.Ctx) string { + tokenBytes := make([]byte, CSRFTokenLength) + if _, err := rand.Read(tokenBytes); err != nil { + return "" + } + + token := base64.URLEncoding.EncodeToString(tokenBytes) + + if sess := GetCurrentSession(ctx); sess != nil { + sess.Set(CSRFSessionKey, token) + session.StoreSession(sess) + } + + return token +} + +// GetCSRFToken retrieves the current CSRF token from session, generating one if needed +func GetCSRFToken(ctx sushi.Ctx) string { + sess := GetCurrentSession(ctx) + if sess == nil { + return "" + } + + if existingToken, ok := sess.Get(CSRFSessionKey); ok { + if tokenStr, ok := existingToken.(string); ok { + return tokenStr + } + } + + return GenerateCSRFToken(ctx) +} + +// ValidateCSRFToken verifies a CSRF token against the stored session token +func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool { + if submittedToken == "" { + return false + } + + sess := GetCurrentSession(ctx) + if sess == nil { + return false + } + + storedToken, ok := sess.Get(CSRFSessionKey) + if !ok { + return false + } + + storedTokenStr, ok := storedToken.(string) + if !ok { + return false + } + + return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1 +} + +// CSRFHiddenField generates an HTML hidden input field with the CSRF token +func CSRFHiddenField(ctx sushi.Ctx) string { + token := GetCSRFToken(ctx) + if token == "" { + return "" + } + + return fmt.Sprintf(``, + CSRFTokenFieldName, token) +} + +// CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token +func CSRFTokenMeta(ctx sushi.Ctx) string { + token := GetCSRFToken(ctx) + if token == "" { + return "" + } + + return fmt.Sprintf(``, token) +} + +// ValidateFormCSRFToken validates CSRF token from form data +func ValidateFormCSRFToken(ctx sushi.Ctx) bool { + tokenBytes := ctx.PostArgs().Peek(CSRFTokenFieldName) + if len(tokenBytes) == 0 { + tokenBytes = ctx.QueryArgs().Peek(CSRFTokenFieldName) + } + + if len(tokenBytes) == 0 { + return false + } + + return ValidateCSRFToken(ctx, string(tokenBytes)) +} + +// Middleware returns middleware that automatically validates CSRF tokens +func Middleware() sushi.Middleware { + return func(next sushi.Handler) sushi.Handler { + return func(ctx sushi.Ctx, params []string) { + method := string(ctx.Method()) + + if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" { + if !ValidateFormCSRFToken(ctx) { + GenerateCSRFToken(ctx) + currentPath := string(ctx.Path()) + ctx.Redirect(currentPath, 302) + return + } + } + + next(ctx, params) + } + } +} diff --git a/fs.go b/fs.go new file mode 100644 index 0000000..52eaf4e --- /dev/null +++ b/fs.go @@ -0,0 +1,126 @@ +package sushi + +import ( + "fmt" + "path" + "strings" + "time" + + "github.com/valyala/fasthttp" +) + +// StaticOptions configures static file serving +type StaticOptions struct { + Root string + IndexNames []string + GenerateIndexPages bool + Compress bool + CompressBrotli bool + AcceptByteRange bool + CacheDuration time.Duration + MaxAge int +} + +// StaticFS creates a handler for serving static files with fasthttp.FS +func StaticFS(fsOptions StaticOptions) Handler { + // Set defaults + if fsOptions.Root == "" { + fsOptions.Root = "." + } + if len(fsOptions.IndexNames) == 0 { + fsOptions.IndexNames = []string{"index.html"} + } + + fs := &fasthttp.FS{ + Root: fsOptions.Root, + IndexNames: fsOptions.IndexNames, + GenerateIndexPages: fsOptions.GenerateIndexPages, + Compress: fsOptions.Compress, + CompressBrotli: fsOptions.CompressBrotli, + AcceptByteRange: fsOptions.AcceptByteRange, + CacheDuration: fsOptions.CacheDuration, + } + + if fsOptions.MaxAge > 0 { + fs.PathRewrite = func(ctx *fasthttp.RequestCtx) []byte { + ctx.Response.Header.Set("Cache-Control", fmt.Sprintf("max-age=%d", fsOptions.MaxAge)) + return ctx.Path() + } + } + + fsHandler := fs.NewRequestHandler() + + return func(ctx Ctx, params []string) { + fsHandler(ctx) + } +} + +// Static creates a simple static file handler +func Static(root string) Handler { + return StaticFS(StaticOptions{Root: root}) +} + +// StaticFile serves a single file +func StaticFile(filePath string) Handler { + return func(ctx Ctx, params []string) { + fasthttp.ServeFile(ctx, filePath) + } +} + +// StaticEmbed creates a handler for embedded files +func StaticEmbed(files map[string][]byte) Handler { + return func(ctx Ctx, params []string) { + requestPath := string(ctx.Path()) + + // Try to find the file + if data, exists := files[requestPath]; exists { + // Set content type based on extension + ext := path.Ext(requestPath) + contentType := getContentType(ext) + ctx.SetContentType(contentType) + ctx.Write(data) + return + } + + // Try index files + if requestPath == "/" || strings.HasSuffix(requestPath, "/") { + indexPath := requestPath + "index.html" + if data, exists := files[indexPath]; exists { + ctx.SetContentType("text/html") + ctx.Write(data) + return + } + } + + ctx.SetStatusCode(fasthttp.StatusNotFound) + } +} + +func getContentType(ext string) string { + switch ext { + case ".html", ".htm": + return "text/html" + case ".css": + return "text/css" + case ".js": + return "application/javascript" + case ".json": + return "application/json" + case ".png": + return "image/png" + case ".jpg", ".jpeg": + return "image/jpeg" + case ".gif": + return "image/gif" + case ".svg": + return "image/svg+xml" + case ".ico": + return "image/x-icon" + case ".pdf": + return "application/pdf" + case ".txt": + return "text/plain" + default: + return "application/octet-stream" + } +} diff --git a/go.mod b/go.mod index d94bb3a..0fac803 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,15 @@ module git.sharkk.net/Sharkk/Sushi go 1.24.1 + +require ( + github.com/valyala/fasthttp v1.65.0 + golang.org/x/crypto v0.41.0 +) + +require ( + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + golang.org/x/sys v0.35.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..2188cbe --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8= +github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/password/password.go b/password/password.go new file mode 100644 index 0000000..2c09565 --- /dev/null +++ b/password/password.go @@ -0,0 +1,79 @@ +package password + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" +) + +const ( + argonTime = 1 + argonMemory = 64 * 1024 + argonThreads = 4 + argonKeyLen = 32 +) + +// HashPassword creates an argon2id hash of the password +func HashPassword(password string) string { + salt := make([]byte, 16) + rand.Read(salt) + + hash := argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen) + + b64Salt := base64.RawStdEncoding.EncodeToString(salt) + b64Hash := base64.RawStdEncoding.EncodeToString(hash) + + encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, argonMemory, argonTime, argonThreads, b64Salt, b64Hash) + + return encoded +} + +// VerifyPassword checks if a password matches the hash +func VerifyPassword(password, encodedHash string) (bool, error) { + parts := strings.Split(encodedHash, "$") + if len(parts) != 6 { + return false, fmt.Errorf("invalid hash format") + } + + if parts[1] != "argon2id" { + return false, fmt.Errorf("invalid hash variant") + } + + var version int + _, err := fmt.Sscanf(parts[2], "v=%d", &version) + if err != nil { + return false, err + } + if version != argon2.Version { + return false, fmt.Errorf("incompatible argon2 version") + } + + var m, t, p uint32 + _, err = fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &m, &t, &p) + if err != nil { + return false, err + } + + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return false, err + } + + expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return false, err + } + + hash := argon2.IDKey([]byte(password), salt, t, m, uint8(p), uint32(len(expectedHash))) + + if subtle.ConstantTimeCompare(hash, expectedHash) == 1 { + return true, nil + } + + return false, nil +} diff --git a/router.go b/router.go new file mode 100644 index 0000000..2e28862 --- /dev/null +++ b/router.go @@ -0,0 +1,275 @@ +package sushi + +import ( + "fmt" + + "github.com/valyala/fasthttp" +) + +type node struct { + segment string + handler Handler + children []*node + isDynamic bool + isWildcard bool + maxParams uint8 +} + +type Router struct { + get *node + post *node + put *node + patch *node + delete *node + middleware []Middleware + paramsBuffer []string +} + +type Group struct { + router *Router + prefix string + middleware []Middleware +} + +// New creates a new Router instance +func NewRouter() *Router { + return &Router{ + get: &node{}, + post: &node{}, + put: &node{}, + patch: &node{}, + delete: &node{}, + middleware: []Middleware{}, + paramsBuffer: make([]string, 64), + } +} + +// ServeHTTP implements the Handler interface for fasthttp +func (r *Router) ServeHTTP(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + method := string(ctx.Method()) + + h, params, found := r.Lookup(method, path) + if !found { + ctx.SetStatusCode(fasthttp.StatusNotFound) + return + } + + h(ctx, params) +} + +// Handler returns a fasthttp request handler +func (r *Router) Handler() fasthttp.RequestHandler { + return r.ServeHTTP +} + +// Use adds middleware to the router +func (r *Router) Use(mw ...Middleware) *Router { + r.middleware = append(r.middleware, mw...) + return r +} + +// Group creates a new route group +func (r *Router) Group(prefix string) *Group { + return &Group{router: r, prefix: prefix, middleware: []Middleware{}} +} + +// Use adds middleware to the group +func (g *Group) Use(mw ...Middleware) *Group { + g.middleware = append(g.middleware, mw...) + return g +} + +// Group creates a nested group +func (g *Group) Group(prefix string) *Group { + return &Group{ + router: g.router, + prefix: g.prefix + prefix, + middleware: append([]Middleware{}, g.middleware...), + } +} + +// HTTP method handlers for Router +func (r *Router) Get(path string, h Handler) error { return r.Handle("GET", path, h) } +func (r *Router) Post(path string, h Handler) error { return r.Handle("POST", path, h) } +func (r *Router) Put(path string, h Handler) error { return r.Handle("PUT", path, h) } +func (r *Router) Patch(path string, h Handler) error { return r.Handle("PATCH", path, h) } +func (r *Router) Delete(path string, h Handler) error { return r.Handle("DELETE", path, h) } + +// HTTP method handlers for Group +func (g *Group) Get(path string, h Handler) error { return g.Handle("GET", path, h) } +func (g *Group) Post(path string, h Handler) error { return g.Handle("POST", path, h) } +func (g *Group) Put(path string, h Handler) error { return g.Handle("PUT", path, h) } +func (g *Group) Patch(path string, h Handler) error { return g.Handle("PATCH", path, h) } +func (g *Group) Delete(path string, h Handler) error { return g.Handle("DELETE", path, h) } + +// Handle registers a handler for the given method and path +func (r *Router) Handle(method, path string, h Handler) error { + root := r.methodNode(method) + if root == nil { + return fmt.Errorf("unsupported method: %s", method) + } + return r.addRoute(root, path, h, r.middleware) +} + +// Handle registers a handler in the group +func (g *Group) Handle(method, path string, h Handler) error { + root := g.router.methodNode(method) + if root == nil { + return fmt.Errorf("unsupported method: %s", method) + } + mw := append([]Middleware{}, g.router.middleware...) + mw = append(mw, g.middleware...) + return g.router.addRoute(root, g.prefix+path, h, mw) +} + +func (r *Router) methodNode(method string) *node { + switch method { + case "GET": + return r.get + case "POST": + return r.post + case "PUT": + return r.put + case "PATCH": + return r.patch + case "DELETE": + return r.delete + default: + return nil + } +} + +func applyMiddleware(h Handler, mw []Middleware) Handler { + for i := len(mw) - 1; i >= 0; i-- { + h = mw[i](h) + } + return h +} + +func readSegment(path string, start int) (segment string, end int, hasMore bool) { + if start >= len(path) { + return "", start, false + } + if path[start] == '/' { + start++ + } + if start >= len(path) { + return "", start, false + } + end = start + for end < len(path) && path[end] != '/' { + end++ + } + return path[start:end], end, end < len(path) +} + +func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error { + h = applyMiddleware(h, mw) + if path == "/" { + root.handler = h + return nil + } + current := root + pos := 0 + lastWC := false + count := uint8(0) + for { + seg, newPos, more := readSegment(path, pos) + if seg == "" { + break + } + isDyn := len(seg) > 1 && seg[0] == ':' + isWC := len(seg) > 0 && seg[0] == '*' + if isWC { + if lastWC || more { + return fmt.Errorf("wildcard must be the last segment in the path") + } + lastWC = true + } + if isDyn || isWC { + count++ + } + var child *node + for _, c := range current.children { + if c.segment == seg { + child = c + break + } + } + if child == nil { + child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC} + current.children = append(current.children, child) + } + if child.maxParams < count { + child.maxParams = count + } + current = child + pos = newPos + } + current.handler = h + return nil +} + +// Lookup finds a handler matching method and path +func (r *Router) Lookup(method, path string) (Handler, []string, bool) { + root := r.methodNode(method) + if root == nil { + return nil, nil, false + } + if path == "/" { + return root.handler, nil, root.handler != nil + } + + buffer := r.paramsBuffer + if cap(buffer) < int(root.maxParams) { + buffer = make([]string, root.maxParams) + r.paramsBuffer = buffer + } + buffer = buffer[:0] + + h, paramCount, found := match(root, path, 0, &buffer) + if !found { + return nil, nil, false + } + + return h, buffer[:paramCount], true +} + +func match(current *node, path string, start int, params *[]string) (Handler, int, bool) { + paramCount := 0 + + for _, c := range current.children { + if c.isWildcard { + rem := path[start:] + if len(rem) > 0 && rem[0] == '/' { + rem = rem[1:] + } + *params = append(*params, rem) + return c.handler, 1, c.handler != nil + } + } + + seg, pos, more := readSegment(path, start) + if seg == "" { + return current.handler, 0, current.handler != nil + } + + for _, c := range current.children { + if c.segment == seg || c.isDynamic { + if c.isDynamic { + *params = append(*params, seg) + paramCount++ + } + if !more { + return c.handler, paramCount, c.handler != nil + } + h, nestedCount, ok := match(c, path, pos, params) + if ok { + return h, paramCount + nestedCount, true + } + } + } + + return nil, 0, false +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..5a6a06e --- /dev/null +++ b/server.go @@ -0,0 +1,33 @@ +package sushi + +// Listen starts the server on the specified address +func (a *App) Listen(addr ...string) error { + address := ":8080" + if len(addr) > 0 && addr[0] != "" { + address = addr[0] + } + return a.Server.ListenAndServe(address) +} + +// ListenTLS starts the server with TLS on the specified address +func (a *App) ListenTLS(addr, certFile, keyFile string) error { + return a.Server.ListenAndServeTLS(addr, certFile, keyFile) +} + +// Use adds middleware to the app's router +func (a *App) Use(mw ...Middleware) *App { + a.Router.Use(mw...) + return a +} + +// Group creates a new route group on the app's router +func (a *App) Group(prefix string) *Group { + return a.Router.Group(prefix) +} + +// HTTP method handlers for App +func (a *App) Get(path string, h Handler) error { return a.Router.Get(path, h) } +func (a *App) Post(path string, h Handler) error { return a.Router.Post(path, h) } +func (a *App) Put(path string, h Handler) error { return a.Router.Put(path, h) } +func (a *App) Patch(path string, h Handler) error { return a.Router.Patch(path, h) } +func (a *App) Delete(path string, h Handler) error { return a.Router.Delete(path, h) } diff --git a/session/middleware.go b/session/middleware.go new file mode 100644 index 0000000..a6013c3 --- /dev/null +++ b/session/middleware.go @@ -0,0 +1,30 @@ +package session + +import sushi "git.sharkk.net/Sharkk/Sushi" + +// Middleware provides session handling +func Middleware() sushi.Middleware { + return func(next sushi.Handler) sushi.Handler { + return func(ctx sushi.Ctx, params []string) { + sessionID := sushi.GetCookie(ctx, SessionCookieName) + var sess *Session + + if sessionID != "" { + if existingSess, exists := GetSession(sessionID); exists { + sess = existingSess + sess.Touch() + StoreSession(sess) + SetSessionCookie(ctx, sessionID) + } + } + + if sess == nil { + sess = CreateSession(0) // Guest session + SetSessionCookie(ctx, sess.ID) + } + + ctx.SetUserValue(SessionCtxKey, sess) + next(ctx, params) + } + } +} diff --git a/session/session.go b/session/session.go new file mode 100644 index 0000000..36c9774 --- /dev/null +++ b/session/session.go @@ -0,0 +1,280 @@ +package session + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "os" + "sync" + "time" + + sushi "git.sharkk.net/Sharkk/Sushi" +) + +const ( + DefaultExpiration = 24 * time.Hour + IDLength = 32 + SessionCookieName = "session_id" + SessionCtxKey = "session" +) + +// Session represents a user session +type Session struct { + ID string `json:"id"` + UserID int `json:"user_id"` + ExpiresAt int64 `json:"expires_at"` + Data map[string]any `json:"data"` +} + +// SessionManager handles session storage and persistence +type SessionManager struct { + mu sync.RWMutex + sessions map[string]*Session + filePath string +} + +type sessionData struct { + UserID int `json:"user_id"` + ExpiresAt int64 `json:"expires_at"` + Data map[string]any `json:"data"` +} + +var sessionManager *SessionManager + +// InitSessions initializes the global session manager +func InitSessions(filePath string) { + if sessionManager != nil { + panic("session manager already initialized") + } + + sessionManager = &SessionManager{ + sessions: make(map[string]*Session), + filePath: filePath, + } + + sessionManager.load() +} + +// NewSession creates a new session +func NewSession(userID int) *Session { + return &Session{ + ID: generateSessionID(), + UserID: userID, + ExpiresAt: time.Now().Add(DefaultExpiration).Unix(), + Data: make(map[string]any), + } +} + +func generateSessionID() string { + bytes := make([]byte, IDLength) + rand.Read(bytes) + return hex.EncodeToString(bytes) +} + +// Session methods +func (s *Session) IsExpired() bool { + return time.Now().Unix() > s.ExpiresAt +} + +func (s *Session) Touch() { + s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix() +} + +func (s *Session) Set(key string, value any) { + s.Data[key] = value +} + +func (s *Session) Get(key string) (any, bool) { + value, exists := s.Data[key] + return value, exists +} + +func (s *Session) Delete(key string) { + delete(s.Data, key) +} + +func (s *Session) SetFlash(key string, value any) { + s.Set("flash_"+key, value) +} + +func (s *Session) GetFlash(key string) (any, bool) { + flashKey := "flash_" + key + value, exists := s.Get(flashKey) + if exists { + s.Delete(flashKey) + } + return value, exists +} + +func (s *Session) GetFlashMessage(key string) string { + if flash, exists := s.GetFlash(key); exists { + if msg, ok := flash.(string); ok { + return msg + } + } + return "" +} + +func (s *Session) RegenerateID() { + oldID := s.ID + s.ID = generateSessionID() + + if sessionManager != nil { + sessionManager.mu.Lock() + delete(sessionManager.sessions, oldID) + sessionManager.sessions[s.ID] = s + sessionManager.mu.Unlock() + } +} + +func (s *Session) SetUserID(userID int) { + s.UserID = userID +} + +// GetCurrentSession retrieves the session from context +func GetCurrentSession(ctx sushi.Ctx) *Session { + if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok { + return sess + } + return nil +} + +// SessionManager methods +func (sm *SessionManager) Create(userID int) *Session { + sess := NewSession(userID) + sm.mu.Lock() + sm.sessions[sess.ID] = sess + sm.mu.Unlock() + return sess +} + +func (sm *SessionManager) Get(sessionID string) (*Session, bool) { + sm.mu.RLock() + sess, exists := sm.sessions[sessionID] + sm.mu.RUnlock() + + if !exists || sess.IsExpired() { + if exists { + sm.Delete(sessionID) + } + return nil, false + } + + return sess, true +} + +func (sm *SessionManager) Store(sess *Session) { + sm.mu.Lock() + sm.sessions[sess.ID] = sess + sm.mu.Unlock() +} + +func (sm *SessionManager) Delete(sessionID string) { + sm.mu.Lock() + delete(sm.sessions, sessionID) + sm.mu.Unlock() +} + +func (sm *SessionManager) Cleanup() { + sm.mu.Lock() + for id, sess := range sm.sessions { + if sess.IsExpired() { + delete(sm.sessions, id) + } + } + sm.mu.Unlock() +} + +func (sm *SessionManager) load() { + if sm.filePath == "" { + return + } + + data, err := os.ReadFile(sm.filePath) + if err != nil { + return + } + + var sessionsData map[string]*sessionData + if err := json.Unmarshal(data, &sessionsData); err != nil { + return + } + + now := time.Now().Unix() + sm.mu.Lock() + for id, data := range sessionsData { + if data != nil && data.ExpiresAt > now { + sess := &Session{ + ID: id, + UserID: data.UserID, + ExpiresAt: data.ExpiresAt, + Data: data.Data, + } + if sess.Data == nil { + sess.Data = make(map[string]any) + } + sm.sessions[id] = sess + } + } + sm.mu.Unlock() +} + +func (sm *SessionManager) Save() error { + if sm.filePath == "" { + return nil + } + + sm.Cleanup() + + sm.mu.RLock() + sessionsData := make(map[string]*sessionData, len(sm.sessions)) + for id, sess := range sm.sessions { + sessionsData[id] = &sessionData{ + UserID: sess.UserID, + ExpiresAt: sess.ExpiresAt, + Data: sess.Data, + } + } + + data, err := json.MarshalIndent(sessionsData, "", "\t") + sm.mu.RUnlock() + + if err != nil { + return err + } + + return os.WriteFile(sm.filePath, data, 0600) +} + +// Package-level session functions +func CreateSession(userID int) *Session { + return sessionManager.Create(userID) +} + +func GetSession(sessionID string) (*Session, bool) { + return sessionManager.Get(sessionID) +} + +func StoreSession(sess *Session) { + sessionManager.Store(sess) +} + +func CleanupSessions() { + sessionManager.Cleanup() +} + +func SaveSessions() error { + return sessionManager.Save() +} + +func SetSessionCookie(ctx sushi.Ctx, sessionID string) { + sushi.SetSecureCookie(ctx, sushi.CookieOptions{ + Name: SessionCookieName, + Value: sessionID, + Path: "/", + Expires: time.Now().Add(24 * time.Hour), + HTTPOnly: true, + Secure: sushi.IsHTTPS(ctx), + SameSite: "lax", + }) +} diff --git a/sushi.go b/sushi.go new file mode 100644 index 0000000..bb3876c --- /dev/null +++ b/sushi.go @@ -0,0 +1,89 @@ +// Package sushi provides a complete FastHTTP-based web framework +// with routing, sessions, authentication, CSRF protection, and utilities. +package sushi + +import ( + "time" + + "github.com/valyala/fasthttp" +) + +func (h Handler) Serve(ctx Ctx, params []string) { + h(ctx, params) +} + +func IsHTTPS(ctx Ctx) bool { + return ctx.IsTLS() || + string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" || + string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https" +} + +// StandardHandler adapts a standard fasthttp.RequestHandler to the router's Handler +func StandardHandler(handler fasthttp.RequestHandler) Handler { + return func(ctx Ctx, _ []string) { + handler(ctx) + } +} + +// ServerOptions contains configuration for the FastHTTP server +type ServerOptions struct { + Addr string + ReadTimeout time.Duration + WriteTimeout time.Duration + IdleTimeout time.Duration + MaxConnsPerIP int + MaxRequestsPerConn int + MaxRequestBodySize int + ReduceMemoryUsage bool + DisableKeepalive bool + TCPKeepalive bool + TCPKeepalivePeriod time.Duration +} + +type App struct { + *fasthttp.Server + Router *Router +} + +// New creates a new App instance with FastHTTP server and router +func New(opts ...ServerOptions) *App { + var options ServerOptions + if len(opts) > 0 { + options = opts[0] + } + + // Set defaults + if options.ReadTimeout == 0 { + options.ReadTimeout = 10 * time.Second + } + if options.WriteTimeout == 0 { + options.WriteTimeout = 10 * time.Second + } + if options.IdleTimeout == 0 { + options.IdleTimeout = 60 * time.Second + } + if options.MaxRequestBodySize == 0 { + options.MaxRequestBodySize = 4 * 1024 * 1024 // 4MB + } + + router := NewRouter() + + app := &App{ + Server: &fasthttp.Server{ + Handler: router.Handler(), + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + IdleTimeout: options.IdleTimeout, + MaxConnsPerIP: options.MaxConnsPerIP, + MaxRequestsPerConn: options.MaxRequestsPerConn, + MaxRequestBodySize: options.MaxRequestBodySize, + ReduceMemoryUsage: options.ReduceMemoryUsage, + DisableKeepalive: options.DisableKeepalive, + TCPKeepalive: options.TCPKeepalive, + TCPKeepalivePeriod: options.TCPKeepalivePeriod, + }, + Router: router, + } + + return app +} diff --git a/timing/timing.go b/timing/timing.go new file mode 100644 index 0000000..eda682d --- /dev/null +++ b/timing/timing.go @@ -0,0 +1,48 @@ +package timing + +import ( + "fmt" + "time" + + sushi "git.sharkk.net/Sharkk/Sushi" +) + +const RequestTimerKey = "request_start_time" + +// Middleware adds request timing functionality +func Middleware() sushi.Middleware { + return func(next sushi.Handler) sushi.Handler { + return func(ctx sushi.Ctx, params []string) { + startTime := time.Now() + ctx.SetUserValue(RequestTimerKey, startTime) + next(ctx, params) + } + } +} + +// GetRequestTime returns the total request processing time in seconds (formatted) +func GetRequestTime(ctx sushi.Ctx) string { + startTime, ok := ctx.UserValue(RequestTimerKey).(time.Time) + if !ok { + return "0" + } + + duration := time.Since(startTime) + seconds := duration.Seconds() + + if seconds < 0.001 { + return "0" + } + + return fmt.Sprintf("%.3f", seconds) +} + +// GetRequestDuration returns the raw duration +func GetRequestDuration(ctx sushi.Ctx) time.Duration { + startTime, ok := ctx.UserValue(RequestTimerKey).(time.Time) + if !ok { + return 0 + } + + return time.Since(startTime) +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..b4b20fa --- /dev/null +++ b/types.go @@ -0,0 +1,7 @@ +package sushi + +import "github.com/valyala/fasthttp" + +type Ctx = *fasthttp.RequestCtx +type Handler func(ctx Ctx, params []string) +type Middleware func(Handler) Handler