Moonshark/sessions/session.go

430 lines
9.0 KiB
Go

package sessions
import (
"fmt"
"sync"
"time"
"github.com/deneonet/benc"
bstd "github.com/deneonet/benc/std"
)
// Session stores data for a single user session.
//
// Values placed in Data are expected to be of the types accepted by
// validate (nil, string, int, int64, float64, bool, []byte,
// map[string]any, []any); those are the only types marshalAny encodes
// faithfully. Timestamps are serialized via Unix(), i.e. with one-second
// precision.
type Session struct {
	ID        string         // session identifier; set by NewSession or decoded by UnmarshalPlain
	Data      map[string]any // session key/value pairs; the map is reused by NewSession from the pool
	CreatedAt time.Time      // set to the creation instant by NewSession
	UpdatedAt time.Time      // bumped by Set/Delete/Clear when Data changes
	LastUsed  time.Time      // access time, advanced by UpdateLastUsed at most once every 5 seconds
	Expiry    time.Time      // absolute expiry instant; NewSession sets it to creation time + maxAge seconds
	dirty     bool           // true when Data has changes not yet acknowledged via ResetDirty
}
// sessionPool recycles *Session values together with their Data maps so
// that NewSession and Unmarshal can avoid allocating a fresh session (and
// map) each time. New pre-sizes the map for 8 entries.
var sessionPool = sync.Pool{
	New: func() any {
		fresh := &Session{Data: make(map[string]any, 8)}
		return fresh
	},
}

// bufPool supplies scratch buffers to Session.Marshal; it is configured
// with benc.WithBufferSize(4096).
var bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
// NewSession returns a session with the given ID that expires maxAge
// seconds from now. The session is taken from sessionPool and its Data map
// is reused when available; all other fields are reset, so the session
// starts clean (not dirty).
func NewSession(id string, maxAge int) *Session {
	s := sessionPool.Get().(*Session)

	data := s.Data
	if data == nil {
		// A pooled session can come back without a map (for example after a
		// failed decode set Data to nil); Set would panic writing to it.
		data = make(map[string]any, 8)
	}

	now := time.Now()
	*s = Session{
		ID:        id,
		Data:      data,
		CreatedAt: now,
		UpdatedAt: now,
		LastUsed:  now,
		Expiry:    now.Add(time.Duration(maxAge) * time.Second),
	}
	return s
}
// Release empties the session and returns it to sessionPool for reuse.
// The caller must not use s after calling Release.
//
// The Data map is emptied in place so its storage can be reused. Every other
// field (ID, timestamps, dirty flag) is zeroed, so no state from this session
// carries over to whoever gets the pooled object next.
func (s *Session) Release() {
	data := s.Data
	if data == nil {
		// Never pool a session without a map: NewSession reuses it and Set
		// would panic on a nil map.
		data = make(map[string]any, 8)
	} else {
		for k := range data {
			delete(data, k)
		}
	}
	*s = Session{Data: data}
	sessionPool.Put(s)
}
// Get returns a copy (made with deepCopy) of the value stored under key,
// or nil when the key is absent. Nested map[string]any and []any values
// are copied recursively.
func (s *Session) Get(key string) any {
	v, found := s.Data[key]
	if !found {
		return nil
	}
	return deepCopy(v)
}
// GetTable returns the value under key as a map[string]any (a copy, via
// Get), or nil when the key is missing or holds a different type.
func (s *Session) GetTable(key string) map[string]any {
	table, _ := s.Get(key).(map[string]any)
	return table
}
// GetAll returns a new map containing a deepCopy of every session value.
// The top-level map, and any nested map[string]any / []any values, are
// freshly allocated.
func (s *Session) GetAll() map[string]any {
	// Named out rather than copy to avoid shadowing the builtin copy.
	out := make(map[string]any, len(s.Data))
	for k, v := range s.Data {
		out[k] = deepCopy(v)
	}
	return out
}
// Set stores value under key, bumps UpdatedAt and marks the session dirty.
// If the current value is already deepEqual to value, Set does nothing.
//
// The value is stored as a deepCopy, mirroring Get, which also returns
// copies. Otherwise the caller would keep a live reference into session
// data. Mutating that reference and calling Set again would then compare
// the stored value with itself, report "no change", and the update would
// never be flagged as dirty.
func (s *Session) Set(key string, value any) {
	if existing, ok := s.Data[key]; ok && deepEqual(existing, value) {
		return // no change
	}
	s.Data[key] = deepCopy(value)
	s.UpdatedAt = time.Now()
	s.dirty = true
}
// SetSafe checks value with validate and, only if it is serializable,
// stores it via Set. A validation failure is returned wrapped with the
// "session.SetSafe" prefix and leaves the session unchanged.
func (s *Session) SetSafe(key string, value any) error {
	err := validate(value)
	if err == nil {
		s.Set(key, value)
		return nil
	}
	return fmt.Errorf("session.SetSafe: %w", err)
}
// SetTable validates and stores a table (map) value under key; it is
// shorthand for SetSafe and returns the same errors.
func (s *Session) SetTable(key string, table map[string]any) error {
	var value any = table
	return s.SetSafe(key, value)
}
// Delete removes key from the session. When the key was present, it bumps
// UpdatedAt and marks the session dirty. Deleting a missing key is a no-op,
// consistent with Set's no-change short-circuit, so it does not report
// unsaved changes that do not exist.
func (s *Session) Delete(key string) {
	if _, ok := s.Data[key]; !ok {
		return
	}
	delete(s.Data, key)
	s.UpdatedAt = time.Now()
	s.dirty = true
}
// Clear drops all session values by installing a fresh, empty map (sized
// for 8 entries), then bumps UpdatedAt and marks the session dirty.
func (s *Session) Clear() {
	fresh := make(map[string]any, 8)
	s.Data, s.UpdatedAt, s.dirty = fresh, time.Now(), true
}
// IsExpired reports whether the current time is strictly past Expiry.
func (s *Session) IsExpired() bool {
	return s.Expiry.Before(time.Now())
}
// UpdateLastUsed sets LastUsed to now, but only when more than five seconds
// have elapsed since the previous value, which throttles how often it
// changes. It does not mark the session dirty.
func (s *Session) UpdateLastUsed() {
	const throttle = 5 * time.Second
	if now := time.Now(); now.Sub(s.LastUsed) > throttle {
		s.LastUsed = now
	}
}
// IsDirty reports whether session data changed since the last ResetDirty
// (or since creation).
func (s *Session) IsDirty() bool { return s.dirty }
// ResetDirty clears the modification flag; call it once the session's
// changes have been saved.
func (s *Session) ResetDirty() { s.dirty = false }
// SizePlain returns the exact number of bytes MarshalPlain writes. That is
// the ID string, the Data map (string keys, values sized by sizeAny) and
// four int64 timestamps.
func (s *Session) SizePlain() int {
	const timestampCount = 4 // CreatedAt, UpdatedAt, LastUsed, Expiry
	size := bstd.SizeString(s.ID)
	size += bstd.SizeMap(s.Data, bstd.SizeString, sizeAny)
	size += timestampCount * bstd.SizeInt64()
	return size
}
// MarshalPlain encodes the session into b starting at offset n and returns
// the offset just past the encoding. Fields are written in this order: ID,
// Data, CreatedAt, UpdatedAt, LastUsed, Expiry. Timestamps are stored as
// Unix seconds, so sub-second precision is dropped.
func (s *Session) MarshalPlain(n int, b []byte) int {
	n = bstd.MarshalString(n, b, s.ID)
	n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny)
	for _, stamp := range [...]time.Time{s.CreatedAt, s.UpdatedAt, s.LastUsed, s.Expiry} {
		n = bstd.MarshalInt64(n, b, stamp.Unix())
	}
	return n
}
// UnmarshalPlain decodes a session written by MarshalPlain from b starting
// at offset n, assigning directly into s. It returns the offset after the
// last field read. On error it returns the error immediately, and fields
// decoded so far (or the one being decoded) may already have been
// overwritten.
func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) {
	var err error
	if n, s.ID, err = bstd.UnmarshalString(n, b); err != nil {
		return n, err
	}
	if n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny); err != nil {
		return n, err
	}
	stamps := [...]*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry}
	for i := range stamps {
		var sec int64
		if n, sec, err = bstd.UnmarshalInt64(n, b); err != nil {
			return n, err
		}
		*stamps[i] = time.Unix(sec, 0)
	}
	return n, nil
}
// Marshal encodes the whole session with MarshalPlain into a buffer sized by
// SizePlain, obtained through bufPool.
func (s *Session) Marshal() ([]byte, error) {
	size := s.SizePlain()
	write := func(buf []byte) int { return s.MarshalPlain(0, buf) }
	return bufPool.Marshal(size, write)
}
// Unmarshal decodes a session produced by Session.Marshal. The returned
// session comes from sessionPool and starts out not dirty. On failure the
// pooled session is released and the decode error is returned unchanged.
func Unmarshal(data []byte) (*Session, error) {
	s := sessionPool.Get().(*Session)
	// UnmarshalPlain overwrites ID, Data and the timestamps but not dirty, so
	// reset everything to avoid inheriting state from the previous owner.
	*s = Session{Data: s.Data}
	if _, err := s.UnmarshalPlain(0, data); err != nil {
		// A failed map decode leaves Data nil; don't put a map-less session
		// back into the pool, since NewSession reuses the map.
		if s.Data == nil {
			s.Data = make(map[string]any, 8)
		}
		s.Release()
		return nil, err
	}
	return s, nil
}
// Type tags written as the first byte of every value encoded by marshalAny
// and dispatched on by unmarshalAny. They are part of the encoded format:
// renumbering them would make previously encoded sessions decode wrongly.
const (
	typeNull   byte = iota // 0: nil
	typeString             // 1: string
	typeInt                // 2: int / int64, encoded as int64
	typeFloat              // 3: float64
	typeBool               // 4: bool
	typeBytes              // 5: []byte
	typeTable              // 6: map[string]any
	typeArray              // 7: []any
)
// sizeAny returns the number of bytes marshalAny writes for v: one type-tag
// byte plus the bstd size of the payload. It must stay in step with
// marshalAny. Unsupported types are sized as the string "unknown", which
// is what marshalAny writes for them.
func sizeAny(v any) int {
	const tagSize = 1
	switch x := v.(type) {
	case nil:
		return tagSize
	case string:
		return tagSize + bstd.SizeString(x)
	case int, int64:
		return tagSize + bstd.SizeInt64()
	case float64:
		return tagSize + bstd.SizeFloat64()
	case bool:
		return tagSize + bstd.SizeBool()
	case []byte:
		return tagSize + bstd.SizeBytes(x)
	case map[string]any:
		return tagSize + bstd.SizeMap(x, bstd.SizeString, sizeAny)
	case []any:
		return tagSize + bstd.SizeSlice(x, sizeAny)
	default:
		return tagSize + bstd.SizeString("unknown")
	}
}
// marshalAny writes v at offset n of b as a one-byte type tag followed by
// the bstd encoding of the payload, and returns the offset just past it.
// b is presumably sized using sizeAny, as Session.Marshal does via
// SizePlain. Both int and int64 are written as int64 under typeInt. Any
// unsupported type is written as the string "unknown" under typeString,
// so such values do not survive a round trip.
func marshalAny(n int, b []byte, v any) int {
	payload := n + 1
	switch x := v.(type) {
	case nil:
		b[n] = typeNull
		return payload
	case string:
		b[n] = typeString
		return bstd.MarshalString(payload, b, x)
	case int:
		b[n] = typeInt
		return bstd.MarshalInt64(payload, b, int64(x))
	case int64:
		b[n] = typeInt
		return bstd.MarshalInt64(payload, b, x)
	case float64:
		b[n] = typeFloat
		return bstd.MarshalFloat64(payload, b, x)
	case bool:
		b[n] = typeBool
		return bstd.MarshalBool(payload, b, x)
	case []byte:
		b[n] = typeBytes
		return bstd.MarshalBytes(payload, b, x)
	case map[string]any:
		b[n] = typeTable
		return bstd.MarshalMap(payload, b, x, bstd.MarshalString, marshalAny)
	case []any:
		b[n] = typeArray
		return bstd.MarshalSlice(payload, b, x, marshalAny)
	default:
		b[n] = typeString
		return bstd.MarshalString(payload, b, "unknown")
	}
}
// unmarshalAny decodes one value written by marshalAny from b at offset n.
// It returns the offset after the value, the value, and any error.
// Integers always decode as int64, regardless of whether an int or int64
// was encoded. []byte payloads are copied out of b.
//
// It returns benc.ErrBufTooSmall if no tag byte is available. An
// unrecognised tag means the data is corrupt or foreign: the payload length
// cannot be known, so an error is returned instead of guessing an offset
// and decoding misaligned bytes.
func unmarshalAny(n int, b []byte) (int, any, error) {
	if n >= len(b) {
		return n, nil, benc.ErrBufTooSmall
	}
	tag, payload := b[n], n+1
	switch tag {
	case typeNull:
		return payload, nil, nil
	case typeString:
		return bstd.UnmarshalString(payload, b)
	case typeInt:
		return bstd.UnmarshalInt64(payload, b)
	case typeFloat:
		return bstd.UnmarshalFloat64(payload, b)
	case typeBool:
		return bstd.UnmarshalBool(payload, b)
	case typeBytes:
		return bstd.UnmarshalBytesCopied(payload, b)
	case typeTable:
		return bstd.UnmarshalMap[string, any](payload, b, bstd.UnmarshalString, unmarshalAny)
	case typeArray:
		return bstd.UnmarshalSlice[any](payload, b, unmarshalAny)
	default:
		return n, nil, fmt.Errorf("sessions: unknown value type tag %d at offset %d", tag, n)
	}
}
// deepCopy returns a copy of v that shares no mutable storage with it.
// map[string]any and []any are copied recursively, and []byte contents are
// duplicated (a nil []byte stays nil). All other values (strings, numbers,
// bools, nil) are immutable and returned as-is.
//
// Get and GetAll rely on this so callers cannot mutate session data in
// place, which would bypass the dirty flag.
func deepCopy(v any) any {
	switch v := v.(type) {
	case map[string]any:
		cp := make(map[string]any, len(v))
		for k, val := range v {
			cp[k] = deepCopy(val)
		}
		return cp
	case []any:
		cp := make([]any, len(v))
		for i, val := range v {
			cp[i] = deepCopy(val)
		}
		return cp
	case []byte:
		if v == nil {
			return v
		}
		cp := make([]byte, len(v))
		copy(cp, v)
		return cp
	default:
		return v
	}
}
// validate reports whether v contains only types that marshalAny can
// encode faithfully: nil, string, int, int64, float64, bool, []byte, and
// map[string]any / []any built from those, checked recursively. The error
// describes the path (key or index) to the first offending value found.
// Map iteration order is random, so with several bad keys the one
// reported may vary.
func validate(v any) error {
	switch x := v.(type) {
	case nil, string, int, int64, float64, bool, []byte:
		return nil
	case map[string]any:
		for key, item := range x {
			if err := validate(item); err != nil {
				return fmt.Errorf("invalid value for key %q: %w", key, err)
			}
		}
		return nil
	case []any:
		for idx, item := range x {
			if err := validate(item); err != nil {
				return fmt.Errorf("invalid value at index %d: %w", idx, err)
			}
		}
		return nil
	default:
		return fmt.Errorf("unsupported type: %T", x)
	}
}
// deepEqual reports whether a and b hold deeply equal session values.
//
// The rules are:
//   - map[string]any and []any are compared element by element, recursively.
//   - []byte is compared by content.
//   - int and int64 are interchangeable when numerically equal, since
//     unmarshalAny decodes every integer as int64.
//   - Any other pair falls back to ==. This fallback is guarded so that
//     uncomparable dynamic types (e.g. map[string]string) yield false
//     instead of panicking.
//
// Set uses it to skip writes that would not change anything.
func deepEqual(a, b any) bool {
	// An up-front a == b would panic whenever both interfaces hold the same
	// uncomparable type (maps, slices), so only nil is handled before
	// dispatching on the dynamic type.
	if a == nil || b == nil {
		return a == nil && b == nil
	}
	switch va := a.(type) {
	case string:
		vb, ok := b.(string)
		return ok && va == vb
	case int:
		switch vb := b.(type) {
		case int:
			return va == vb
		case int64:
			return int64(va) == vb
		}
		return false
	case int64:
		switch vb := b.(type) {
		case int64:
			return va == vb
		case int:
			return va == int64(vb)
		}
		return false
	case float64:
		vb, ok := b.(float64)
		return ok && va == vb
	case bool:
		vb, ok := b.(bool)
		return ok && va == vb
	case []byte:
		vb, ok := b.([]byte)
		// Converting to string inside a comparison does not allocate.
		return ok && string(va) == string(vb)
	case map[string]any:
		vb, ok := b.(map[string]any)
		if !ok || len(va) != len(vb) {
			return false
		}
		for k, v := range va {
			if bv, exists := vb[k]; !exists || !deepEqual(v, bv) {
				return false
			}
		}
		return true
	case []any:
		vb, ok := b.([]any)
		if !ok || len(va) != len(vb) {
			return false
		}
		for i := range va {
			if !deepEqual(va[i], vb[i]) {
				return false
			}
		}
		return true
	default:
		return safeInterfaceEqual(a, b)
	}
}

// safeInterfaceEqual reports a == b, returning false instead of panicking
// when the dynamic types are uncomparable.
func safeInterfaceEqual(a, b any) (equal bool) {
	defer func() {
		if recover() != nil {
			equal = false
		}
	}()
	return a == b
}