// Moonshark/core/sessions/Session.go

package sessions
import (
"sync"
"time"
"github.com/deneonet/benc"
bstd "github.com/deneonet/benc/std"
)
// Session holds the state of one user session: its identifier, the
// key/value payload, and the timestamps used for bookkeeping and expiry.
type Session struct {
	ID   string         // session identifier
	Data map[string]any // key/value payload managed via Get/Set/Delete/Clear

	// Lifecycle timestamps, in declaration order: creation, last
	// modification, last access, and expiry.
	CreatedAt, UpdatedAt, LastUsed, Expiry time.Time

	// dirty records unsaved changes. It is unexported and is not part of
	// the binary encoding.
	dirty bool
}
// sessionPool recycles *Session values to cut allocation churn. Sessions it
// creates from scratch start with a Data map sized for 8 entries.
var sessionPool = sync.Pool{
	New: func() any {
		fresh := new(Session)
		fresh.Data = make(map[string]any, 8)
		return fresh
	},
}
// bufPool supplies reusable serialization buffers to (*Session).Marshal.
// Configured with benc.WithBufferSize(4096) — presumably the buffer size in
// bytes; confirm against the benc documentation.
var bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
// GetFromPool hands out a *Session from the shared pool: either a freshly
// constructed one or one previously reset by ReturnToPool. Callers are
// responsible for filling in ID and the timestamps.
func GetFromPool() *Session {
	pooled := sessionPool.Get()
	return pooled.(*Session)
}
// ReturnToPool resets s and puts it back into the session pool for reuse.
// A nil s is ignored; s should not be used by the caller afterwards.
//
// The Data map is emptied in place so its capacity is reused. If Data is nil
// (for example, a decoder or caller left it unset), a new map is allocated
// instead. This guarantees every session later handed out by GetFromPool can
// be written to without a nil-map panic.
func ReturnToPool(s *Session) {
	if s == nil {
		return
	}
	data := s.Data
	if data == nil {
		data = make(map[string]any, 8)
	} else {
		for k := range data {
			delete(data, k)
		}
	}
	// Zero ID, every timestamp and the dirty flag in a single assignment,
	// keeping only the (now empty) map.
	*s = Session{Data: data}
	sessionPool.Put(s)
}
// NewSession returns a pooled Session identified by id. Its creation,
// update and last-used times are all set to the current time, it expires
// maxAge seconds from now, and it starts with the dirty flag cleared.
func NewSession(id string, maxAge int) *Session {
	ts := time.Now()
	s := GetFromPool()

	s.ID = id
	s.CreatedAt, s.UpdatedAt, s.LastUsed = ts, ts, ts
	lifetime := time.Second * time.Duration(maxAge)
	s.Expiry = ts.Add(lifetime)
	s.dirty = false
	return s
}
// Get returns the value stored under key, or nil when key is absent.
func (s *Session) Get(key string) any {
	if v, ok := s.Data[key]; ok {
		return v
	}
	return nil
}
// Set stores value under key, stamps UpdatedAt with the current time and
// marks the session dirty.
//
// If Data is nil (e.g. a zero-value Session), the map is allocated on first
// write. Without this, the assignment would panic.
func (s *Session) Set(key string, value any) {
	if s.Data == nil {
		s.Data = make(map[string]any, 8)
	}
	s.Data[key] = value
	s.UpdatedAt = time.Now()
	s.dirty = true
}
// Delete removes key from the session data. UpdatedAt is refreshed and the
// session is flagged dirty even when key was not present.
func (s *Session) Delete(key string) {
	delete(s.Data, key)
	s.UpdatedAt, s.dirty = time.Now(), true
}
// Clear drops every entry by replacing Data with a fresh, empty map
// (capacity hint 8); the previous map object is not modified. It then
// refreshes UpdatedAt and marks the session dirty.
func (s *Session) Clear() {
	fresh := make(map[string]any, 8)
	s.Data = fresh
	s.UpdatedAt, s.dirty = time.Now(), true
}
// GetAll returns a shallow copy of the session data. The result is always a
// non-nil map that the caller may modify without affecting the session.
// The values themselves are shared with the session, not deep-copied.
func (s *Session) GetAll() map[string]any {
	// Named "out" rather than "copy" so the builtin copy is not shadowed.
	out := make(map[string]any, len(s.Data))
	for k, v := range s.Data {
		out[k] = v
	}
	return out
}
// IsExpired reports whether the current time is strictly later than Expiry.
func (s *Session) IsExpired() bool {
	return s.Expiry.Before(time.Now())
}
// UpdateLastUsed refreshes LastUsed to the current time, but only when more
// than five seconds have passed since the stored value. The dirty flag is
// deliberately left untouched so that mere activity does not trigger a save.
func (s *Session) UpdateLastUsed() {
	current := time.Now()
	if current.Sub(s.LastUsed) <= 5*time.Second {
		return // refreshed recently; skip the write
	}
	s.LastUsed = current
}
// IsDirty reports whether Set, Delete or Clear has run since the dirty flag
// was last cleared (by NewSession, a pool reset, or ResetDirty).
func (s *Session) IsDirty() bool {
	changed := s.dirty
	return changed
}
// ResetDirty clears the dirty flag. Call it once the session's current
// state has been persisted.
func (s *Session) ResetDirty() {
	s.dirty = false
}
// SizePlain reports how many bytes MarshalPlain writes for s: the ID string,
// the Data map (each value sized by sizeAny), and four int64 Unix timestamps.
func (s *Session) SizePlain() (size int) {
	idSize := bstd.SizeString(s.ID)
	dataSize := bstd.SizeMap(s.Data, bstd.SizeString, sizeAny)
	timeSize := 4 * bstd.SizeInt64()
	return idSize + dataSize + timeSize
}
// MarshalPlain encodes s into b starting at offset n and returns the offset
// just past the written bytes. b is expected to have room for SizePlain()
// bytes from n (Marshal sizes it that way).
//
// Layout: ID, Data map, then CreatedAt, UpdatedAt, LastUsed and Expiry as
// int64 Unix seconds, so sub-second precision is not preserved.
func (s *Session) MarshalPlain(n int, b []byte) int {
	n = bstd.MarshalString(n, b, s.ID)
	n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny)
	for _, ts := range [...]time.Time{s.CreatedAt, s.UpdatedAt, s.LastUsed, s.Expiry} {
		n = bstd.MarshalInt64(n, b, ts.Unix())
	}
	return n
}
// UnmarshalPlain decodes a session written by MarshalPlain from b, starting
// at offset n, and returns the offset just past the consumed bytes.
//
// Every field is decoded into a local first and assigned to s only after
// the whole record has been read. A truncated or corrupt buffer therefore
// returns an error and leaves s untouched, instead of half-overwritten with
// a possibly nil Data map. Data is never left nil on success.
// Timestamps are restored with time.Unix, i.e. one-second resolution.
func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) {
	var (
		id   string
		data map[string]any
		err  error
	)
	if n, id, err = bstd.UnmarshalString(n, b); err != nil {
		return n, err
	}
	if n, data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny); err != nil {
		return n, err
	}

	// CreatedAt, UpdatedAt, LastUsed, Expiry — the order MarshalPlain writes.
	var stamps [4]int64
	for i := range stamps {
		if n, stamps[i], err = bstd.UnmarshalInt64(n, b); err != nil {
			return n, err
		}
	}

	if data == nil {
		// Keep Set safe on the decoded session.
		data = make(map[string]any)
	}
	s.ID = id
	s.Data = data
	s.CreatedAt = time.Unix(stamps[0], 0)
	s.UpdatedAt = time.Unix(stamps[1], 0)
	s.LastUsed = time.Unix(stamps[2], 0)
	s.Expiry = time.Unix(stamps[3], 0)
	return n, nil
}
// Marshal encodes s with MarshalPlain into a buffer obtained from bufPool,
// sized by SizePlain. On error it returns a nil slice together with the
// error.
// NOTE(review): confirm in the benc BufPool docs whether the returned slice
// may be retained indefinitely or aliases a pooled buffer.
func (s *Session) Marshal() ([]byte, error) {
	encode := func(buf []byte) int {
		return s.MarshalPlain(0, buf)
	}
	out, err := bufPool.Marshal(s.SizePlain(), encode)
	if err != nil {
		return nil, err
	}
	return out, nil
}
// Unmarshal decodes data (as produced by Marshal) into a session taken from
// the pool. If decoding fails, the session goes back to the pool and
// (nil, err) is returned.
func Unmarshal(data []byte) (*Session, error) {
	s := GetFromPool()
	if _, err := s.UnmarshalPlain(0, data); err != nil {
		ReturnToPool(s)
		return nil, err
	}
	return s, nil
}
// Type tags written in front of every Data value by marshalAny and read
// back by unmarshalAny. These numbers are part of the stored binary format:
// never renumber them; only append new tags.
const (
	typeNull   byte = 0 // nil value, no payload
	typeString byte = 1 // string payload (also the fallback for unsupported types)
	typeInt    byte = 2 // integer payload, stored as int64
	typeFloat  byte = 3 // floating-point payload, stored as float64
	typeBool   byte = 4 // bool payload
	typeBytes  byte = 5 // []byte payload
)
// widenInt returns v as an int64 when v is an integer type that converts
// losslessly: int, int8..int64, and uint8/uint16/uint32. uint and uint64 are
// excluded because their values may not fit in int64.
func widenInt(v any) (int64, bool) {
	switch x := v.(type) {
	case int:
		return int64(x), true
	case int8:
		return int64(x), true
	case int16:
		return int64(x), true
	case int32:
		return int64(x), true
	case int64:
		return x, true
	case uint8:
		return int64(x), true
	case uint16:
		return int64(x), true
	case uint32:
		return int64(x), true
	}
	return 0, false
}

// widenFloat returns v as a float64 when v is a float32 or float64.
func widenFloat(v any) (float64, bool) {
	switch x := v.(type) {
	case float32:
		return float64(x), true
	case float64:
		return x, true
	}
	return 0, false
}

// sizeAny returns the number of bytes marshalAny writes for v: one type-tag
// byte plus the payload. It classifies values with the same widenInt and
// widenFloat helpers as marshalAny, so the two cannot drift apart and
// undersize the buffer MarshalPlain writes into.
func sizeAny(v any) int {
	if v == nil {
		return 1
	}
	if _, ok := widenInt(v); ok {
		return 1 + bstd.SizeInt64()
	}
	if _, ok := widenFloat(v); ok {
		return 1 + bstd.SizeFloat64()
	}
	switch val := v.(type) {
	case string:
		return 1 + bstd.SizeString(val)
	case bool:
		return 1 + bstd.SizeBool()
	case []byte:
		return 1 + bstd.SizeBytes(val)
	default:
		// Unsupported types are encoded as the string "unknown".
		return 1 + bstd.SizeString("unknown")
	}
}

// marshalAny writes v at b[n] as a one-byte type tag followed by its payload
// and returns the offset just past it.
//
// Supported: nil, string, bool, []byte, the integers accepted by widenInt
// (stored as int64), and float32/float64 (stored as float64). Widened numbers
// decode back as int64/float64. Any other type is written as the string
// "unknown", so its value is lost.
func marshalAny(n int, b []byte, v any) int {
	if v == nil {
		b[n] = typeNull
		return n + 1
	}
	if i, ok := widenInt(v); ok {
		b[n] = typeInt
		return bstd.MarshalInt64(n+1, b, i)
	}
	if f, ok := widenFloat(v); ok {
		b[n] = typeFloat
		return bstd.MarshalFloat64(n+1, b, f)
	}
	switch val := v.(type) {
	case string:
		b[n] = typeString
		return bstd.MarshalString(n+1, b, val)
	case bool:
		b[n] = typeBool
		return bstd.MarshalBool(n+1, b, val)
	case []byte:
		b[n] = typeBytes
		return bstd.MarshalBytes(n+1, b, val)
	default:
		// Convert unhandled types to string
		b[n] = typeString
		return bstd.MarshalString(n+1, b, "unknown")
	}
}
// valueDecodeError is a string-backed error type, so decode errors can be
// declared as constants without importing errors or fmt.
type valueDecodeError string

func (e valueDecodeError) Error() string { return string(e) }

// errUnknownValueType is returned by unmarshalAny for a type tag that
// marshalAny never writes. The payload length behind an unknown tag cannot
// be determined, so decoding must stop rather than read the following bytes
// at the wrong offset.
const errUnknownValueType = valueDecodeError("sessions: unknown value type tag in session data")

// unmarshalAny decodes one value written by marshalAny, starting at the type
// tag at b[n]. It returns the offset just past the value, the decoded value
// (nil, string, int64, float64, bool or a copied []byte), and an error.
// Errors: benc.ErrBufTooSmall if no tag byte is available,
// errUnknownValueType for an unrecognized tag, or whatever the bstd decoder
// for the payload returns.
func unmarshalAny(n int, b []byte) (int, any, error) {
	if n >= len(b) {
		return n, nil, benc.ErrBufTooSmall
	}
	tag := b[n]
	n++

	switch tag {
	case typeNull:
		return n, nil, nil
	case typeString:
		return bstd.UnmarshalString(n, b)
	case typeInt:
		return bstd.UnmarshalInt64(n, b)
	case typeFloat:
		return bstd.UnmarshalFloat64(n, b)
	case typeBool:
		return bstd.UnmarshalBool(n, b)
	case typeBytes:
		// Copied so the value does not alias the input buffer.
		return bstd.UnmarshalBytesCopied(n, b)
	default:
		return n, nil, errUnknownValueType
	}
}