first commit

This commit is contained in:
Sky Johnson 2025-08-15 14:23:09 -05:00
parent baa56b79bd
commit 5bcaa4c89f
14 changed files with 1313 additions and 0 deletions

105
auth/auth.go Normal file
View File

@ -0,0 +1,105 @@
package auth
import (
sushi "git.sharkk.net/Sharkk/Sushi"
"git.sharkk.net/Sharkk/Sushi/session"
"github.com/valyala/fasthttp"
)
const UserCtxKey = "user"
// Middleware adds authentication handling
func Middleware(userLookup func(int) any) sushi.Middleware {
return func(next sushi.Handler) sushi.Handler {
return func(ctx sushi.Ctx, params []string) {
sess := session.GetCurrentSession(ctx)
if sess != nil && sess.UserID > 0 && userLookup != nil {
user := userLookup(sess.UserID)
if user != nil {
ctx.SetUserValue(UserCtxKey, user)
} else {
sess.SetUserID(0)
session.StoreSession(sess)
}
}
next(ctx, params)
}
}
}
// RequireAuth middleware that redirects unauthenticated users
func RequireAuth(redirectPath ...string) sushi.Middleware {
redirect := "/login"
if len(redirectPath) > 0 && redirectPath[0] != "" {
redirect = redirectPath[0]
}
return func(next sushi.Handler) sushi.Handler {
return func(ctx sushi.Ctx, params []string) {
if !IsAuthenticated(ctx) {
ctx.Redirect(redirect, fasthttp.StatusFound)
return
}
next(ctx, params)
}
}
}
// RequireGuest middleware that redirects authenticated users
func RequireGuest(redirectPath ...string) sushi.Middleware {
redirect := "/"
if len(redirectPath) > 0 && redirectPath[0] != "" {
redirect = redirectPath[0]
}
return func(next sushi.Handler) sushi.Handler {
return func(ctx sushi.Ctx, params []string) {
if IsAuthenticated(ctx) {
ctx.Redirect(redirect, fasthttp.StatusFound)
return
}
next(ctx, params)
}
}
}
// IsAuthenticated checks if the current request is from an authenticated user
func IsAuthenticated(ctx sushi.Ctx) bool {
user := ctx.UserValue(UserCtxKey)
return user != nil
}
// GetCurrentUser returns the current authenticated user
func GetCurrentUser(ctx sushi.Ctx) any {
return ctx.UserValue(UserCtxKey)
}
// Login authenticates a user session
func Login(ctx sushi.Ctx, userID int, user any) {
sess := session.GetCurrentSession(ctx)
if sess != nil {
sess.SetUserID(userID)
sess.RegenerateID()
session.StoreSession(sess)
ctx.SetUserValue(session.SessionCtxKey, sess)
ctx.SetUserValue(UserCtxKey, user)
session.SetSessionCookie(ctx, sess.ID)
}
}
// Logout clears the user session
func Logout(ctx sushi.Ctx) {
sess := session.GetCurrentSession(ctx)
if sess != nil {
sess.SetUserID(0)
sess.RegenerateID()
session.StoreSession(sess)
ctx.SetUserValue(session.SessionCtxKey, sess)
session.SetSessionCookie(ctx, sess.ID)
}
ctx.SetUserValue(UserCtxKey, nil)
}

77
cookies.go Normal file
View File

@ -0,0 +1,77 @@
package sushi
import (
"time"
"github.com/valyala/fasthttp"
)
type CookieOptions struct {
Name string
Value string
Path string
Domain string
Expires time.Time
MaxAge int
Secure bool
HTTPOnly bool
SameSite string
}
func SetSecureCookie(ctx *fasthttp.RequestCtx, opts CookieOptions) {
cookie := &fasthttp.Cookie{}
cookie.SetKey(opts.Name)
cookie.SetValue(opts.Value)
if opts.Path != "" {
cookie.SetPath(opts.Path)
} else {
cookie.SetPath("/")
}
if opts.Domain != "" {
cookie.SetDomain(opts.Domain)
}
if !opts.Expires.IsZero() {
cookie.SetExpire(opts.Expires)
}
if opts.MaxAge > 0 {
cookie.SetMaxAge(opts.MaxAge)
}
cookie.SetSecure(opts.Secure)
cookie.SetHTTPOnly(opts.HTTPOnly)
switch opts.SameSite {
case "strict":
cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode)
case "lax":
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
case "none":
cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode)
default:
cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode)
}
ctx.Response.Header.SetCookie(cookie)
}
func GetCookie(ctx *fasthttp.RequestCtx, name string) string {
return string(ctx.Request.Header.Cookie(name))
}
func DeleteCookie(ctx *fasthttp.RequestCtx, name string) {
SetSecureCookie(ctx, CookieOptions{
Name: name,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
MaxAge: -1,
HTTPOnly: true,
Secure: true,
SameSite: "lax",
})
}

138
csrf/csrf.go Normal file
View File

@ -0,0 +1,138 @@
package csrf
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
sushi "git.sharkk.net/Sharkk/Sushi"
"git.sharkk.net/Sharkk/Sushi/session"
)
const (
CSRFTokenLength = 32
CSRFTokenFieldName = "_csrf_token"
CSRFSessionKey = "csrf_token"
SessionCtxKey = "session"
)
// GetCurrentSession retrieves the session from context
func GetCurrentSession(ctx sushi.Ctx) *session.Session {
if sess, ok := ctx.UserValue(SessionCtxKey).(*session.Session); ok {
return sess
}
return nil
}
// GenerateCSRFToken creates a new CSRF token and stores it in the session
func GenerateCSRFToken(ctx sushi.Ctx) string {
tokenBytes := make([]byte, CSRFTokenLength)
if _, err := rand.Read(tokenBytes); err != nil {
return ""
}
token := base64.URLEncoding.EncodeToString(tokenBytes)
if sess := GetCurrentSession(ctx); sess != nil {
sess.Set(CSRFSessionKey, token)
session.StoreSession(sess)
}
return token
}
// GetCSRFToken retrieves the current CSRF token from session, generating one if needed
func GetCSRFToken(ctx sushi.Ctx) string {
sess := GetCurrentSession(ctx)
if sess == nil {
return ""
}
if existingToken, ok := sess.Get(CSRFSessionKey); ok {
if tokenStr, ok := existingToken.(string); ok {
return tokenStr
}
}
return GenerateCSRFToken(ctx)
}
// ValidateCSRFToken verifies a CSRF token against the stored session token
func ValidateCSRFToken(ctx sushi.Ctx, submittedToken string) bool {
if submittedToken == "" {
return false
}
sess := GetCurrentSession(ctx)
if sess == nil {
return false
}
storedToken, ok := sess.Get(CSRFSessionKey)
if !ok {
return false
}
storedTokenStr, ok := storedToken.(string)
if !ok {
return false
}
return subtle.ConstantTimeCompare([]byte(submittedToken), []byte(storedTokenStr)) == 1
}
// CSRFHiddenField generates an HTML hidden input field with the CSRF token
func CSRFHiddenField(ctx sushi.Ctx) string {
token := GetCSRFToken(ctx)
if token == "" {
return ""
}
return fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
CSRFTokenFieldName, token)
}
// CSRFTokenMeta generates HTML meta tag for JavaScript access to CSRF token
func CSRFTokenMeta(ctx sushi.Ctx) string {
token := GetCSRFToken(ctx)
if token == "" {
return ""
}
return fmt.Sprintf(`<meta name="csrf-token" content="%s">`, token)
}
// ValidateFormCSRFToken validates CSRF token from form data
func ValidateFormCSRFToken(ctx sushi.Ctx) bool {
tokenBytes := ctx.PostArgs().Peek(CSRFTokenFieldName)
if len(tokenBytes) == 0 {
tokenBytes = ctx.QueryArgs().Peek(CSRFTokenFieldName)
}
if len(tokenBytes) == 0 {
return false
}
return ValidateCSRFToken(ctx, string(tokenBytes))
}
// Middleware returns middleware that automatically validates CSRF tokens
func Middleware() sushi.Middleware {
return func(next sushi.Handler) sushi.Handler {
return func(ctx sushi.Ctx, params []string) {
method := string(ctx.Method())
if method == "POST" || method == "PUT" || method == "PATCH" || method == "DELETE" {
if !ValidateFormCSRFToken(ctx) {
GenerateCSRFToken(ctx)
currentPath := string(ctx.Path())
ctx.Redirect(currentPath, 302)
return
}
}
next(ctx, params)
}
}
}

126
fs.go Normal file
View File

@ -0,0 +1,126 @@
package sushi
import (
"fmt"
"path"
"strings"
"time"
"github.com/valyala/fasthttp"
)
// StaticOptions configures static file serving
type StaticOptions struct {
Root string
IndexNames []string
GenerateIndexPages bool
Compress bool
CompressBrotli bool
AcceptByteRange bool
CacheDuration time.Duration
MaxAge int
}
// StaticFS creates a handler for serving static files with fasthttp.FS
func StaticFS(fsOptions StaticOptions) Handler {
// Set defaults
if fsOptions.Root == "" {
fsOptions.Root = "."
}
if len(fsOptions.IndexNames) == 0 {
fsOptions.IndexNames = []string{"index.html"}
}
fs := &fasthttp.FS{
Root: fsOptions.Root,
IndexNames: fsOptions.IndexNames,
GenerateIndexPages: fsOptions.GenerateIndexPages,
Compress: fsOptions.Compress,
CompressBrotli: fsOptions.CompressBrotli,
AcceptByteRange: fsOptions.AcceptByteRange,
CacheDuration: fsOptions.CacheDuration,
}
if fsOptions.MaxAge > 0 {
fs.PathRewrite = func(ctx *fasthttp.RequestCtx) []byte {
ctx.Response.Header.Set("Cache-Control", fmt.Sprintf("max-age=%d", fsOptions.MaxAge))
return ctx.Path()
}
}
fsHandler := fs.NewRequestHandler()
return func(ctx Ctx, params []string) {
fsHandler(ctx)
}
}
// Static creates a simple static file handler
func Static(root string) Handler {
return StaticFS(StaticOptions{Root: root})
}
// StaticFile serves a single file
func StaticFile(filePath string) Handler {
return func(ctx Ctx, params []string) {
fasthttp.ServeFile(ctx, filePath)
}
}
// StaticEmbed creates a handler for embedded files
func StaticEmbed(files map[string][]byte) Handler {
return func(ctx Ctx, params []string) {
requestPath := string(ctx.Path())
// Try to find the file
if data, exists := files[requestPath]; exists {
// Set content type based on extension
ext := path.Ext(requestPath)
contentType := getContentType(ext)
ctx.SetContentType(contentType)
ctx.Write(data)
return
}
// Try index files
if requestPath == "/" || strings.HasSuffix(requestPath, "/") {
indexPath := requestPath + "index.html"
if data, exists := files[indexPath]; exists {
ctx.SetContentType("text/html")
ctx.Write(data)
return
}
}
ctx.SetStatusCode(fasthttp.StatusNotFound)
}
}
func getContentType(ext string) string {
switch ext {
case ".html", ".htm":
return "text/html"
case ".css":
return "text/css"
case ".js":
return "application/javascript"
case ".json":
return "application/json"
case ".png":
return "image/png"
case ".jpg", ".jpeg":
return "image/jpeg"
case ".gif":
return "image/gif"
case ".svg":
return "image/svg+xml"
case ".ico":
return "image/x-icon"
case ".pdf":
return "application/pdf"
case ".txt":
return "text/plain"
default:
return "application/octet-stream"
}
}

12
go.mod
View File

@ -1,3 +1,15 @@
module git.sharkk.net/Sharkk/Sushi
go 1.24.1
require (
github.com/valyala/fasthttp v1.65.0
golang.org/x/crypto v0.41.0
)
require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/sys v0.35.0 // indirect
)

14
go.sum Normal file
View File

@ -0,0 +1,14 @@
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.65.0 h1:j/u3uzFEGFfRxw79iYzJN+TteTJwbYkru9uDp3d0Yf8=
github.com/valyala/fasthttp v1.65.0/go.mod h1:P/93/YkKPMsKSnATEeELUCkG8a7Y+k99uxNHVbKINr4=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=

79
password/password.go Normal file
View File

@ -0,0 +1,79 @@
package password
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"strings"
"golang.org/x/crypto/argon2"
)
const (
argonTime = 1
argonMemory = 64 * 1024
argonThreads = 4
argonKeyLen = 32
)
// HashPassword creates an argon2id hash of the password
func HashPassword(password string) string {
salt := make([]byte, 16)
rand.Read(salt)
hash := argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version, argonMemory, argonTime, argonThreads, b64Salt, b64Hash)
return encoded
}
// VerifyPassword checks if a password matches the hash
func VerifyPassword(password, encodedHash string) (bool, error) {
parts := strings.Split(encodedHash, "$")
if len(parts) != 6 {
return false, fmt.Errorf("invalid hash format")
}
if parts[1] != "argon2id" {
return false, fmt.Errorf("invalid hash variant")
}
var version int
_, err := fmt.Sscanf(parts[2], "v=%d", &version)
if err != nil {
return false, err
}
if version != argon2.Version {
return false, fmt.Errorf("incompatible argon2 version")
}
var m, t, p uint32
_, err = fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &m, &t, &p)
if err != nil {
return false, err
}
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return false, err
}
expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return false, err
}
hash := argon2.IDKey([]byte(password), salt, t, m, uint8(p), uint32(len(expectedHash)))
if subtle.ConstantTimeCompare(hash, expectedHash) == 1 {
return true, nil
}
return false, nil
}

275
router.go Normal file
View File

@ -0,0 +1,275 @@
package sushi
import (
"fmt"
"github.com/valyala/fasthttp"
)
type node struct {
segment string
handler Handler
children []*node
isDynamic bool
isWildcard bool
maxParams uint8
}
type Router struct {
get *node
post *node
put *node
patch *node
delete *node
middleware []Middleware
paramsBuffer []string
}
type Group struct {
router *Router
prefix string
middleware []Middleware
}
// New creates a new Router instance
func NewRouter() *Router {
return &Router{
get: &node{},
post: &node{},
put: &node{},
patch: &node{},
delete: &node{},
middleware: []Middleware{},
paramsBuffer: make([]string, 64),
}
}
// ServeHTTP implements the Handler interface for fasthttp
func (r *Router) ServeHTTP(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
method := string(ctx.Method())
h, params, found := r.Lookup(method, path)
if !found {
ctx.SetStatusCode(fasthttp.StatusNotFound)
return
}
h(ctx, params)
}
// Handler returns a fasthttp request handler
func (r *Router) Handler() fasthttp.RequestHandler {
return r.ServeHTTP
}
// Use adds middleware to the router
func (r *Router) Use(mw ...Middleware) *Router {
r.middleware = append(r.middleware, mw...)
return r
}
// Group creates a new route group
func (r *Router) Group(prefix string) *Group {
return &Group{router: r, prefix: prefix, middleware: []Middleware{}}
}
// Use adds middleware to the group
func (g *Group) Use(mw ...Middleware) *Group {
g.middleware = append(g.middleware, mw...)
return g
}
// Group creates a nested group
func (g *Group) Group(prefix string) *Group {
return &Group{
router: g.router,
prefix: g.prefix + prefix,
middleware: append([]Middleware{}, g.middleware...),
}
}
// HTTP method handlers for Router
func (r *Router) Get(path string, h Handler) error { return r.Handle("GET", path, h) }
func (r *Router) Post(path string, h Handler) error { return r.Handle("POST", path, h) }
func (r *Router) Put(path string, h Handler) error { return r.Handle("PUT", path, h) }
func (r *Router) Patch(path string, h Handler) error { return r.Handle("PATCH", path, h) }
func (r *Router) Delete(path string, h Handler) error { return r.Handle("DELETE", path, h) }
// HTTP method handlers for Group
func (g *Group) Get(path string, h Handler) error { return g.Handle("GET", path, h) }
func (g *Group) Post(path string, h Handler) error { return g.Handle("POST", path, h) }
func (g *Group) Put(path string, h Handler) error { return g.Handle("PUT", path, h) }
func (g *Group) Patch(path string, h Handler) error { return g.Handle("PATCH", path, h) }
func (g *Group) Delete(path string, h Handler) error { return g.Handle("DELETE", path, h) }
// Handle registers a handler for the given method and path
func (r *Router) Handle(method, path string, h Handler) error {
root := r.methodNode(method)
if root == nil {
return fmt.Errorf("unsupported method: %s", method)
}
return r.addRoute(root, path, h, r.middleware)
}
// Handle registers a handler in the group
func (g *Group) Handle(method, path string, h Handler) error {
root := g.router.methodNode(method)
if root == nil {
return fmt.Errorf("unsupported method: %s", method)
}
mw := append([]Middleware{}, g.router.middleware...)
mw = append(mw, g.middleware...)
return g.router.addRoute(root, g.prefix+path, h, mw)
}
func (r *Router) methodNode(method string) *node {
switch method {
case "GET":
return r.get
case "POST":
return r.post
case "PUT":
return r.put
case "PATCH":
return r.patch
case "DELETE":
return r.delete
default:
return nil
}
}
func applyMiddleware(h Handler, mw []Middleware) Handler {
for i := len(mw) - 1; i >= 0; i-- {
h = mw[i](h)
}
return h
}
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
if start >= len(path) {
return "", start, false
}
if path[start] == '/' {
start++
}
if start >= len(path) {
return "", start, false
}
end = start
for end < len(path) && path[end] != '/' {
end++
}
return path[start:end], end, end < len(path)
}
func (r *Router) addRoute(root *node, path string, h Handler, mw []Middleware) error {
h = applyMiddleware(h, mw)
if path == "/" {
root.handler = h
return nil
}
current := root
pos := 0
lastWC := false
count := uint8(0)
for {
seg, newPos, more := readSegment(path, pos)
if seg == "" {
break
}
isDyn := len(seg) > 1 && seg[0] == ':'
isWC := len(seg) > 0 && seg[0] == '*'
if isWC {
if lastWC || more {
return fmt.Errorf("wildcard must be the last segment in the path")
}
lastWC = true
}
if isDyn || isWC {
count++
}
var child *node
for _, c := range current.children {
if c.segment == seg {
child = c
break
}
}
if child == nil {
child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC}
current.children = append(current.children, child)
}
if child.maxParams < count {
child.maxParams = count
}
current = child
pos = newPos
}
current.handler = h
return nil
}
// Lookup finds a handler matching method and path
func (r *Router) Lookup(method, path string) (Handler, []string, bool) {
root := r.methodNode(method)
if root == nil {
return nil, nil, false
}
if path == "/" {
return root.handler, nil, root.handler != nil
}
buffer := r.paramsBuffer
if cap(buffer) < int(root.maxParams) {
buffer = make([]string, root.maxParams)
r.paramsBuffer = buffer
}
buffer = buffer[:0]
h, paramCount, found := match(root, path, 0, &buffer)
if !found {
return nil, nil, false
}
return h, buffer[:paramCount], true
}
func match(current *node, path string, start int, params *[]string) (Handler, int, bool) {
paramCount := 0
for _, c := range current.children {
if c.isWildcard {
rem := path[start:]
if len(rem) > 0 && rem[0] == '/' {
rem = rem[1:]
}
*params = append(*params, rem)
return c.handler, 1, c.handler != nil
}
}
seg, pos, more := readSegment(path, start)
if seg == "" {
return current.handler, 0, current.handler != nil
}
for _, c := range current.children {
if c.segment == seg || c.isDynamic {
if c.isDynamic {
*params = append(*params, seg)
paramCount++
}
if !more {
return c.handler, paramCount, c.handler != nil
}
h, nestedCount, ok := match(c, path, pos, params)
if ok {
return h, paramCount + nestedCount, true
}
}
}
return nil, 0, false
}

33
server.go Normal file
View File

@ -0,0 +1,33 @@
package sushi
// Listen starts the server on the specified address
func (a *App) Listen(addr ...string) error {
address := ":8080"
if len(addr) > 0 && addr[0] != "" {
address = addr[0]
}
return a.Server.ListenAndServe(address)
}
// ListenTLS starts the server with TLS on the specified address
func (a *App) ListenTLS(addr, certFile, keyFile string) error {
return a.Server.ListenAndServeTLS(addr, certFile, keyFile)
}
// Use adds middleware to the app's router
func (a *App) Use(mw ...Middleware) *App {
a.Router.Use(mw...)
return a
}
// Group creates a new route group on the app's router
func (a *App) Group(prefix string) *Group {
return a.Router.Group(prefix)
}
// HTTP method handlers for App
func (a *App) Get(path string, h Handler) error { return a.Router.Get(path, h) }
func (a *App) Post(path string, h Handler) error { return a.Router.Post(path, h) }
func (a *App) Put(path string, h Handler) error { return a.Router.Put(path, h) }
func (a *App) Patch(path string, h Handler) error { return a.Router.Patch(path, h) }
func (a *App) Delete(path string, h Handler) error { return a.Router.Delete(path, h) }

30
session/middleware.go Normal file
View File

@ -0,0 +1,30 @@
package session
import sushi "git.sharkk.net/Sharkk/Sushi"
// Middleware provides session handling
func Middleware() sushi.Middleware {
return func(next sushi.Handler) sushi.Handler {
return func(ctx sushi.Ctx, params []string) {
sessionID := sushi.GetCookie(ctx, SessionCookieName)
var sess *Session
if sessionID != "" {
if existingSess, exists := GetSession(sessionID); exists {
sess = existingSess
sess.Touch()
StoreSession(sess)
SetSessionCookie(ctx, sessionID)
}
}
if sess == nil {
sess = CreateSession(0) // Guest session
SetSessionCookie(ctx, sess.ID)
}
ctx.SetUserValue(SessionCtxKey, sess)
next(ctx, params)
}
}
}

280
session/session.go Normal file
View File

@ -0,0 +1,280 @@
package session
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"os"
"sync"
"time"
sushi "git.sharkk.net/Sharkk/Sushi"
)
const (
DefaultExpiration = 24 * time.Hour
IDLength = 32
SessionCookieName = "session_id"
SessionCtxKey = "session"
)
// Session represents a user session
type Session struct {
ID string `json:"id"`
UserID int `json:"user_id"`
ExpiresAt int64 `json:"expires_at"`
Data map[string]any `json:"data"`
}
// SessionManager handles session storage and persistence
type SessionManager struct {
mu sync.RWMutex
sessions map[string]*Session
filePath string
}
type sessionData struct {
UserID int `json:"user_id"`
ExpiresAt int64 `json:"expires_at"`
Data map[string]any `json:"data"`
}
var sessionManager *SessionManager
// InitSessions initializes the global session manager
func InitSessions(filePath string) {
if sessionManager != nil {
panic("session manager already initialized")
}
sessionManager = &SessionManager{
sessions: make(map[string]*Session),
filePath: filePath,
}
sessionManager.load()
}
// NewSession creates a new session
func NewSession(userID int) *Session {
return &Session{
ID: generateSessionID(),
UserID: userID,
ExpiresAt: time.Now().Add(DefaultExpiration).Unix(),
Data: make(map[string]any),
}
}
func generateSessionID() string {
bytes := make([]byte, IDLength)
rand.Read(bytes)
return hex.EncodeToString(bytes)
}
// Session methods
func (s *Session) IsExpired() bool {
return time.Now().Unix() > s.ExpiresAt
}
func (s *Session) Touch() {
s.ExpiresAt = time.Now().Add(DefaultExpiration).Unix()
}
func (s *Session) Set(key string, value any) {
s.Data[key] = value
}
func (s *Session) Get(key string) (any, bool) {
value, exists := s.Data[key]
return value, exists
}
func (s *Session) Delete(key string) {
delete(s.Data, key)
}
func (s *Session) SetFlash(key string, value any) {
s.Set("flash_"+key, value)
}
func (s *Session) GetFlash(key string) (any, bool) {
flashKey := "flash_" + key
value, exists := s.Get(flashKey)
if exists {
s.Delete(flashKey)
}
return value, exists
}
func (s *Session) GetFlashMessage(key string) string {
if flash, exists := s.GetFlash(key); exists {
if msg, ok := flash.(string); ok {
return msg
}
}
return ""
}
func (s *Session) RegenerateID() {
oldID := s.ID
s.ID = generateSessionID()
if sessionManager != nil {
sessionManager.mu.Lock()
delete(sessionManager.sessions, oldID)
sessionManager.sessions[s.ID] = s
sessionManager.mu.Unlock()
}
}
func (s *Session) SetUserID(userID int) {
s.UserID = userID
}
// GetCurrentSession retrieves the session from context
func GetCurrentSession(ctx sushi.Ctx) *Session {
if sess, ok := ctx.UserValue(SessionCtxKey).(*Session); ok {
return sess
}
return nil
}
// SessionManager methods
func (sm *SessionManager) Create(userID int) *Session {
sess := NewSession(userID)
sm.mu.Lock()
sm.sessions[sess.ID] = sess
sm.mu.Unlock()
return sess
}
func (sm *SessionManager) Get(sessionID string) (*Session, bool) {
sm.mu.RLock()
sess, exists := sm.sessions[sessionID]
sm.mu.RUnlock()
if !exists || sess.IsExpired() {
if exists {
sm.Delete(sessionID)
}
return nil, false
}
return sess, true
}
func (sm *SessionManager) Store(sess *Session) {
sm.mu.Lock()
sm.sessions[sess.ID] = sess
sm.mu.Unlock()
}
func (sm *SessionManager) Delete(sessionID string) {
sm.mu.Lock()
delete(sm.sessions, sessionID)
sm.mu.Unlock()
}
func (sm *SessionManager) Cleanup() {
sm.mu.Lock()
for id, sess := range sm.sessions {
if sess.IsExpired() {
delete(sm.sessions, id)
}
}
sm.mu.Unlock()
}
func (sm *SessionManager) load() {
if sm.filePath == "" {
return
}
data, err := os.ReadFile(sm.filePath)
if err != nil {
return
}
var sessionsData map[string]*sessionData
if err := json.Unmarshal(data, &sessionsData); err != nil {
return
}
now := time.Now().Unix()
sm.mu.Lock()
for id, data := range sessionsData {
if data != nil && data.ExpiresAt > now {
sess := &Session{
ID: id,
UserID: data.UserID,
ExpiresAt: data.ExpiresAt,
Data: data.Data,
}
if sess.Data == nil {
sess.Data = make(map[string]any)
}
sm.sessions[id] = sess
}
}
sm.mu.Unlock()
}
func (sm *SessionManager) Save() error {
if sm.filePath == "" {
return nil
}
sm.Cleanup()
sm.mu.RLock()
sessionsData := make(map[string]*sessionData, len(sm.sessions))
for id, sess := range sm.sessions {
sessionsData[id] = &sessionData{
UserID: sess.UserID,
ExpiresAt: sess.ExpiresAt,
Data: sess.Data,
}
}
data, err := json.MarshalIndent(sessionsData, "", "\t")
sm.mu.RUnlock()
if err != nil {
return err
}
return os.WriteFile(sm.filePath, data, 0600)
}
// Package-level session functions
func CreateSession(userID int) *Session {
return sessionManager.Create(userID)
}
func GetSession(sessionID string) (*Session, bool) {
return sessionManager.Get(sessionID)
}
func StoreSession(sess *Session) {
sessionManager.Store(sess)
}
func CleanupSessions() {
sessionManager.Cleanup()
}
func SaveSessions() error {
return sessionManager.Save()
}
func SetSessionCookie(ctx sushi.Ctx, sessionID string) {
sushi.SetSecureCookie(ctx, sushi.CookieOptions{
Name: SessionCookieName,
Value: sessionID,
Path: "/",
Expires: time.Now().Add(24 * time.Hour),
HTTPOnly: true,
Secure: sushi.IsHTTPS(ctx),
SameSite: "lax",
})
}

89
sushi.go Normal file
View File

@ -0,0 +1,89 @@
// Package sushi provides a complete FastHTTP-based web framework
// with routing, sessions, authentication, CSRF protection, and utilities.
package sushi
import (
"time"
"github.com/valyala/fasthttp"
)
func (h Handler) Serve(ctx Ctx, params []string) {
h(ctx, params)
}
func IsHTTPS(ctx Ctx) bool {
return ctx.IsTLS() ||
string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" ||
string(ctx.Request.Header.Peek("X-Forwarded-Scheme")) == "https"
}
// StandardHandler adapts a standard fasthttp.RequestHandler to the router's Handler
func StandardHandler(handler fasthttp.RequestHandler) Handler {
return func(ctx Ctx, _ []string) {
handler(ctx)
}
}
// ServerOptions contains configuration for the FastHTTP server
type ServerOptions struct {
Addr string
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
MaxConnsPerIP int
MaxRequestsPerConn int
MaxRequestBodySize int
ReduceMemoryUsage bool
DisableKeepalive bool
TCPKeepalive bool
TCPKeepalivePeriod time.Duration
}
type App struct {
*fasthttp.Server
Router *Router
}
// New creates a new App instance with FastHTTP server and router
func New(opts ...ServerOptions) *App {
var options ServerOptions
if len(opts) > 0 {
options = opts[0]
}
// Set defaults
if options.ReadTimeout == 0 {
options.ReadTimeout = 10 * time.Second
}
if options.WriteTimeout == 0 {
options.WriteTimeout = 10 * time.Second
}
if options.IdleTimeout == 0 {
options.IdleTimeout = 60 * time.Second
}
if options.MaxRequestBodySize == 0 {
options.MaxRequestBodySize = 4 * 1024 * 1024 // 4MB
}
router := NewRouter()
app := &App{
Server: &fasthttp.Server{
Handler: router.Handler(),
ReadTimeout: options.ReadTimeout,
WriteTimeout: options.WriteTimeout,
IdleTimeout: options.IdleTimeout,
MaxConnsPerIP: options.MaxConnsPerIP,
MaxRequestsPerConn: options.MaxRequestsPerConn,
MaxRequestBodySize: options.MaxRequestBodySize,
ReduceMemoryUsage: options.ReduceMemoryUsage,
DisableKeepalive: options.DisableKeepalive,
TCPKeepalive: options.TCPKeepalive,
TCPKeepalivePeriod: options.TCPKeepalivePeriod,
},
Router: router,
}
return app
}

48
timing/timing.go Normal file
View File

@ -0,0 +1,48 @@
package timing
import (
"fmt"
"time"
sushi "git.sharkk.net/Sharkk/Sushi"
)
const RequestTimerKey = "request_start_time"
// Middleware adds request timing functionality
func Middleware() sushi.Middleware {
return func(next sushi.Handler) sushi.Handler {
return func(ctx sushi.Ctx, params []string) {
startTime := time.Now()
ctx.SetUserValue(RequestTimerKey, startTime)
next(ctx, params)
}
}
}
// GetRequestTime returns the total request processing time in seconds (formatted)
func GetRequestTime(ctx sushi.Ctx) string {
startTime, ok := ctx.UserValue(RequestTimerKey).(time.Time)
if !ok {
return "0"
}
duration := time.Since(startTime)
seconds := duration.Seconds()
if seconds < 0.001 {
return "0"
}
return fmt.Sprintf("%.3f", seconds)
}
// GetRequestDuration returns the raw duration
func GetRequestDuration(ctx sushi.Ctx) time.Duration {
startTime, ok := ctx.UserValue(RequestTimerKey).(time.Time)
if !ok {
return 0
}
return time.Since(startTime)
}

7
types.go Normal file
View File

@ -0,0 +1,7 @@
package sushi
import "github.com/valyala/fasthttp"
type Ctx = *fasthttp.RequestCtx
type Handler func(ctx Ctx, params []string)
type Middleware func(Handler) Handler