312 lines
6.4 KiB
Go
312 lines
6.4 KiB
Go
package sessions
|
|
|
|
import (
	"strconv"
	"sync"
	"time"

	"github.com/deneonet/benc"
	bstd "github.com/deneonet/benc/std"
)
|
|
|
|
// Session holds the state of one user session: an identifier, a key/value
// payload, and the timestamps that describe its lifecycle.
type Session struct {
	// ID identifies the session.
	ID string
	// Data is the key/value payload; it is read and written through
	// Get, Set, Delete, Clear and GetAll.
	Data map[string]any

	// CreatedAt is the time recorded when NewSession initialised the session.
	CreatedAt time.Time
	// UpdatedAt is refreshed by every mutation of Data (Set, Delete, Clear).
	UpdatedAt time.Time
	// LastUsed is refreshed by UpdateLastUsed.
	LastUsed time.Time
	// Expiry is the instant after which IsExpired reports true.
	Expiry time.Time

	// dirty records whether Data changed since the last ResetDirty.
	// It is unexported and is not part of the serialized form.
	dirty bool
}
|
|
|
|
// Session pool to reduce allocations
|
|
var sessionPool = sync.Pool{
|
|
New: func() any {
|
|
return &Session{
|
|
Data: make(map[string]any, 8),
|
|
}
|
|
},
|
|
}
|
|
|
|
// BufPool for reusing serialization buffers
|
|
var bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
|
|
|
|
// GetFromPool retrieves a session from the pool
|
|
func GetFromPool() *Session {
|
|
return sessionPool.Get().(*Session)
|
|
}
|
|
|
|
// ReturnToPool returns a session to the pool after cleaning it
|
|
func ReturnToPool(s *Session) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
|
|
// Clean the session for reuse
|
|
s.ID = ""
|
|
for k := range s.Data {
|
|
delete(s.Data, k)
|
|
}
|
|
s.CreatedAt = time.Time{}
|
|
s.UpdatedAt = time.Time{}
|
|
s.LastUsed = time.Time{}
|
|
s.Expiry = time.Time{}
|
|
s.dirty = false
|
|
|
|
sessionPool.Put(s)
|
|
}
|
|
|
|
// NewSession creates a new session with the given ID
|
|
func NewSession(id string, maxAge int) *Session {
|
|
now := time.Now()
|
|
|
|
// Get from pool or create new
|
|
session := GetFromPool()
|
|
|
|
// Initialize
|
|
session.ID = id
|
|
session.CreatedAt = now
|
|
session.UpdatedAt = now
|
|
session.LastUsed = now
|
|
session.Expiry = now.Add(time.Duration(maxAge) * time.Second)
|
|
session.dirty = false
|
|
|
|
return session
|
|
}
|
|
|
|
// Get retrieves a value from the session
|
|
func (s *Session) Get(key string) any {
|
|
return s.Data[key]
|
|
}
|
|
|
|
// Set stores a value in the session
|
|
func (s *Session) Set(key string, value any) {
|
|
s.Data[key] = value
|
|
s.UpdatedAt = time.Now()
|
|
s.dirty = true
|
|
}
|
|
|
|
// Delete removes a value from the session
|
|
func (s *Session) Delete(key string) {
|
|
delete(s.Data, key)
|
|
s.UpdatedAt = time.Now()
|
|
s.dirty = true
|
|
}
|
|
|
|
// Clear removes all data from the session
|
|
func (s *Session) Clear() {
|
|
s.Data = make(map[string]any, 8)
|
|
s.UpdatedAt = time.Now()
|
|
s.dirty = true
|
|
}
|
|
|
|
// GetAll returns a copy of all session data
|
|
func (s *Session) GetAll() map[string]any {
|
|
copy := make(map[string]any, len(s.Data))
|
|
for k, v := range s.Data {
|
|
copy[k] = v
|
|
}
|
|
return copy
|
|
}
|
|
|
|
// IsExpired checks if the session has expired
|
|
func (s *Session) IsExpired() bool {
|
|
return time.Now().After(s.Expiry)
|
|
}
|
|
|
|
// UpdateLastUsed updates the last used time
|
|
// Only updates if at least 5 seconds have passed since last update
|
|
func (s *Session) UpdateLastUsed() {
|
|
now := time.Now()
|
|
if now.Sub(s.LastUsed) > 5*time.Second {
|
|
s.LastUsed = now
|
|
// Not marking dirty for LastUsed updates to reduce writes
|
|
}
|
|
}
|
|
|
|
// IsDirty returns if the session has unsaved changes
|
|
func (s *Session) IsDirty() bool {
|
|
return s.dirty
|
|
}
|
|
|
|
// ResetDirty marks the session as clean after saving
|
|
func (s *Session) ResetDirty() {
|
|
s.dirty = false
|
|
}
|
|
|
|
// SizePlain calculates the size needed to marshal the session
|
|
func (s *Session) SizePlain() (size int) {
|
|
// ID
|
|
size += bstd.SizeString(s.ID)
|
|
|
|
// Data (map of string to any)
|
|
// For simplicity, we store data as binary-encoded strings
|
|
// This is a simplification, in a real-world scenario you would handle
|
|
// different types differently
|
|
dataAsStrings := make(map[string]string)
|
|
for k, v := range s.Data {
|
|
dataAsStrings[k] = toString(v)
|
|
}
|
|
size += bstd.SizeMap(dataAsStrings, bstd.SizeString, bstd.SizeString)
|
|
|
|
// Time fields
|
|
size += bstd.SizeInt64() * 4 // Store Unix timestamps for all time fields
|
|
|
|
return size
|
|
}
|
|
|
|
// MarshalPlain serializes the session to binary
|
|
func (s *Session) MarshalPlain(n int, b []byte) (int, error) {
|
|
// ID
|
|
n = bstd.MarshalString(n, b, s.ID)
|
|
|
|
// Data
|
|
dataAsStrings := make(map[string]string)
|
|
for k, v := range s.Data {
|
|
dataAsStrings[k] = toString(v)
|
|
}
|
|
n = bstd.MarshalMap(n, b, dataAsStrings, bstd.MarshalString, bstd.MarshalString)
|
|
|
|
// Time fields as Unix timestamps
|
|
n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix())
|
|
n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix())
|
|
n = bstd.MarshalInt64(n, b, s.LastUsed.Unix())
|
|
n = bstd.MarshalInt64(n, b, s.Expiry.Unix())
|
|
|
|
return n, nil
|
|
}
|
|
|
|
// UnmarshalPlain deserializes the session from binary
|
|
func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) {
|
|
var err error
|
|
|
|
// ID
|
|
n, s.ID, err = bstd.UnmarshalString(n, b)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
|
|
// Data
|
|
var dataAsStrings map[string]string
|
|
n, dataAsStrings, err = bstd.UnmarshalMap[string, string](n, b, bstd.UnmarshalString, bstd.UnmarshalString)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
|
|
// Convert string data back to original types
|
|
s.Data = make(map[string]any, len(dataAsStrings))
|
|
for k, v := range dataAsStrings {
|
|
s.Data[k] = fromString(v)
|
|
}
|
|
|
|
// Time fields
|
|
var timestamp int64
|
|
|
|
// CreatedAt
|
|
n, timestamp, err = bstd.UnmarshalInt64(n, b)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
s.CreatedAt = time.Unix(timestamp, 0)
|
|
|
|
// UpdatedAt
|
|
n, timestamp, err = bstd.UnmarshalInt64(n, b)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
s.UpdatedAt = time.Unix(timestamp, 0)
|
|
|
|
// LastUsed
|
|
n, timestamp, err = bstd.UnmarshalInt64(n, b)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
s.LastUsed = time.Unix(timestamp, 0)
|
|
|
|
// Expiry
|
|
n, timestamp, err = bstd.UnmarshalInt64(n, b)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
s.Expiry = time.Unix(timestamp, 0)
|
|
|
|
return n, nil
|
|
}
|
|
|
|
// Marshal serializes the session using benc
|
|
func (s *Session) Marshal() ([]byte, error) {
|
|
size := s.SizePlain()
|
|
|
|
data, err := bufPool.Marshal(size, func(b []byte) (n int) {
|
|
n, _ = s.MarshalPlain(0, b)
|
|
return n
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return data, nil
|
|
}
|
|
|
|
// Unmarshal deserializes a session using benc
|
|
func Unmarshal(data []byte) (*Session, error) {
|
|
session := GetFromPool()
|
|
_, err := session.UnmarshalPlain(0, data)
|
|
if err != nil {
|
|
ReturnToPool(session)
|
|
return nil, err
|
|
}
|
|
return session, nil
|
|
}
|
|
|
|
// toString and fromString convert session values to and from the string
// form used in the serialized data map. The scheme is intentionally simple;
// a more robust per-type serialization would be needed to preserve
// arbitrary values.

// toString encodes a session value as a string.
//
// Encoding rules:
//   - nil: the empty string
//   - string and []byte: the raw text, with no prefix
//   - int: "i:" followed by the base-10 digits (for example "i:-42")
//   - bool: "b:t" or "b:f"
//   - any other type: "u:" (the value itself is not preserved)
//
// Integers are written in decimal so that every int round-trips through
// fromString. Converting the int to a single Unicode code point instead
// would lose negative values and anything that does not fit in one byte.
func toString(v any) string {
	switch t := v.(type) {
	case nil:
		return ""
	case string:
		return t
	case []byte:
		return string(t)
	case int:
		return "i:" + strconv.Itoa(t)
	case bool:
		if t {
			return "b:t"
		}
		return "b:f"
	default:
		return "u:" // unknown type
	}
}

// fromString decodes a string produced by toString.
//
// Decoding rules:
//   - "": nil
//   - "i:<digits>": the int parsed from the decimal digits
//   - "b:...": true when the payload starts with 't', false otherwise
//   - "u:...": nil
//   - anything else: the string itself
//
// Because plain strings are stored without a prefix, a string value that
// happens to start with "b:" or "u:", or that looks like "i:<digits>", is
// decoded as that type rather than as a string. An "i:" payload that is
// not a valid integer is returned unchanged as a string.
func fromString(s string) any {
	if s == "" {
		return nil
	}
	if len(s) < 2 {
		return s
	}

	switch payload := s[2:]; s[:2] {
	case "i:":
		n, err := strconv.Atoi(payload)
		if err != nil {
			// Not an integer, so treat it as an ordinary string that
			// merely begins with "i:".
			return s
		}
		return n
	case "b:":
		return payload != "" && payload[0] == 't'
	case "u:":
		return nil
	default:
		return s
	}
}
|