435 lines
9.1 KiB
Go
435 lines
9.1 KiB
Go
package sessions
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/deneonet/benc"
|
|
bstd "github.com/deneonet/benc/std"
|
|
)
|
|
|
|
// Session holds the state of one user session: its identifier, the key/value
// payload, and the timestamps used for bookkeeping and expiry.
type Session struct {
	// ID identifies the session.
	ID string
	// Data is the key/value payload stored in the session.
	Data map[string]any

	// CreatedAt is set when the session is created, UpdatedAt whenever the
	// payload is modified, LastUsed by UpdateLastUsed, and Expiry is the
	// instant after which IsExpired reports true.
	CreatedAt, UpdatedAt, LastUsed, Expiry time.Time

	// dirty records whether the payload changed since the last ResetDirty.
	dirty bool
}
|
|
|
|
var (
|
|
sessionPool = sync.Pool{
|
|
New: func() any {
|
|
return &Session{Data: make(map[string]any, 8)}
|
|
},
|
|
}
|
|
bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
|
|
)
|
|
|
|
// NewSession creates a new session with the given ID
|
|
func NewSession(id string, maxAge int) *Session {
|
|
s := sessionPool.Get().(*Session)
|
|
now := time.Now()
|
|
*s = Session{
|
|
ID: id,
|
|
Data: s.Data, // Reuse map
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
LastUsed: now,
|
|
Expiry: now.Add(time.Duration(maxAge) * time.Second),
|
|
}
|
|
return s
|
|
}
|
|
|
|
// Release returns the session to the pool
|
|
func (s *Session) Release() {
|
|
for k := range s.Data {
|
|
delete(s.Data, k)
|
|
}
|
|
sessionPool.Put(s)
|
|
}
|
|
|
|
// Get returns a deep copy of a value
|
|
func (s *Session) Get(key string) any {
|
|
if v, ok := s.Data[key]; ok {
|
|
return deepCopy(v)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetTable returns a value as a table
|
|
func (s *Session) GetTable(key string) map[string]any {
|
|
if v := s.Get(key); v != nil {
|
|
if t, ok := v.(map[string]any); ok {
|
|
return t
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetAll returns a deep copy of all session data
|
|
func (s *Session) GetAll() map[string]any {
|
|
copy := make(map[string]any, len(s.Data))
|
|
for k, v := range s.Data {
|
|
copy[k] = deepCopy(v)
|
|
}
|
|
return copy
|
|
}
|
|
|
|
// Set stores a value in the session
|
|
func (s *Session) Set(key string, value any) {
|
|
if existing, ok := s.Data[key]; ok && deepEqual(existing, value) {
|
|
return // No change
|
|
}
|
|
s.Data[key] = value
|
|
s.UpdatedAt = time.Now()
|
|
s.dirty = true
|
|
}
|
|
|
|
// SetSafe stores a value with validation
|
|
func (s *Session) SetSafe(key string, value any) error {
|
|
if err := validate(value); err != nil {
|
|
return fmt.Errorf("session.SetSafe: %w", err)
|
|
}
|
|
s.Set(key, value)
|
|
return nil
|
|
}
|
|
|
|
// SetTable is a convenience method for setting table data
|
|
func (s *Session) SetTable(key string, table map[string]any) error {
|
|
return s.SetSafe(key, table)
|
|
}
|
|
|
|
// Delete removes a value from the session
|
|
func (s *Session) Delete(key string) {
|
|
delete(s.Data, key)
|
|
s.UpdatedAt = time.Now()
|
|
s.dirty = true
|
|
}
|
|
|
|
// Clear removes all data from the session
|
|
func (s *Session) Clear() {
|
|
s.Data = make(map[string]any, 8)
|
|
s.UpdatedAt = time.Now()
|
|
s.dirty = true
|
|
}
|
|
|
|
// IsExpired checks if the session has expired
|
|
func (s *Session) IsExpired() bool {
|
|
return time.Now().After(s.Expiry)
|
|
}
|
|
|
|
// UpdateLastUsed updates the last used time
|
|
func (s *Session) UpdateLastUsed() {
|
|
now := time.Now()
|
|
if now.Sub(s.LastUsed) > 5*time.Second {
|
|
s.LastUsed = now
|
|
}
|
|
}
|
|
|
|
// IsDirty returns if the session has unsaved changes
|
|
func (s *Session) IsDirty() bool {
|
|
return s.dirty
|
|
}
|
|
|
|
// ResetDirty marks the session as clean after saving
|
|
func (s *Session) ResetDirty() {
|
|
s.dirty = false
|
|
}
|
|
|
|
// SizePlain calculates the size needed to marshal the session
|
|
func (s *Session) SizePlain() int {
|
|
return bstd.SizeString(s.ID) +
|
|
bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) +
|
|
bstd.SizeInt64()*4
|
|
}
|
|
|
|
// MarshalPlain serializes the session to binary
|
|
func (s *Session) MarshalPlain(n int, b []byte) int {
|
|
n = bstd.MarshalString(n, b, s.ID)
|
|
n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny)
|
|
n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix())
|
|
n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix())
|
|
n = bstd.MarshalInt64(n, b, s.LastUsed.Unix())
|
|
return bstd.MarshalInt64(n, b, s.Expiry.Unix())
|
|
}
|
|
|
|
// UnmarshalPlain deserializes the session from binary
|
|
func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) {
|
|
var err error
|
|
n, s.ID, err = bstd.UnmarshalString(n, b)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
|
|
n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
|
|
var ts int64
|
|
for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} {
|
|
n, ts, err = bstd.UnmarshalInt64(n, b)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
*t = time.Unix(ts, 0)
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// Marshal serializes the session using benc
|
|
func (s *Session) Marshal() ([]byte, error) {
|
|
return bufPool.Marshal(s.SizePlain(), func(b []byte) int {
|
|
return s.MarshalPlain(0, b)
|
|
})
|
|
}
|
|
|
|
// Unmarshal deserializes a session using benc
|
|
func Unmarshal(data []byte) (*Session, error) {
|
|
s := sessionPool.Get().(*Session)
|
|
if _, err := s.UnmarshalPlain(0, data); err != nil {
|
|
s.Release()
|
|
return nil, err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
// Type tags. Every encoded value starts with one of these bytes, which tells
// unmarshalAny how to decode the payload that follows. The numeric values
// are part of the serialised format.
const (
	typeNull   byte = 0x00 // nil; no payload
	typeString byte = 0x01 // string (also used for the "unknown" placeholder)
	typeInt    byte = 0x02 // int or int64, written as int64
	typeFloat  byte = 0x03 // float64
	typeBool   byte = 0x04 // bool
	typeBytes  byte = 0x05 // []byte
	typeTable  byte = 0x06 // map[string]any
	typeArray  byte = 0x07 // []any
)
|
|
|
|
// sizeAny calculates the size needed for any value
|
|
func sizeAny(v any) int {
|
|
if v == nil {
|
|
return 1
|
|
}
|
|
|
|
size := 1 // type byte
|
|
switch v := v.(type) {
|
|
case string:
|
|
size += bstd.SizeString(v)
|
|
case int:
|
|
size += bstd.SizeInt64()
|
|
case int64:
|
|
size += bstd.SizeInt64()
|
|
case float64:
|
|
size += bstd.SizeFloat64()
|
|
case bool:
|
|
size += bstd.SizeBool()
|
|
case []byte:
|
|
size += bstd.SizeBytes(v)
|
|
case map[string]any:
|
|
size += bstd.SizeMap(v, bstd.SizeString, sizeAny)
|
|
case []any:
|
|
size += bstd.SizeSlice(v, sizeAny)
|
|
default:
|
|
size += bstd.SizeString("unknown")
|
|
}
|
|
return size
|
|
}
|
|
|
|
// marshalAny serializes any value
|
|
func marshalAny(n int, b []byte, v any) int {
|
|
if v == nil {
|
|
b[n] = typeNull
|
|
return n + 1
|
|
}
|
|
|
|
switch v := v.(type) {
|
|
case string:
|
|
b[n] = typeString
|
|
return bstd.MarshalString(n+1, b, v)
|
|
case int:
|
|
b[n] = typeInt
|
|
return bstd.MarshalInt64(n+1, b, int64(v))
|
|
case int64:
|
|
b[n] = typeInt
|
|
return bstd.MarshalInt64(n+1, b, v)
|
|
case float64:
|
|
b[n] = typeFloat
|
|
return bstd.MarshalFloat64(n+1, b, v)
|
|
case bool:
|
|
b[n] = typeBool
|
|
return bstd.MarshalBool(n+1, b, v)
|
|
case []byte:
|
|
b[n] = typeBytes
|
|
return bstd.MarshalBytes(n+1, b, v)
|
|
case map[string]any:
|
|
b[n] = typeTable
|
|
return bstd.MarshalMap(n+1, b, v, bstd.MarshalString, marshalAny)
|
|
case []any:
|
|
b[n] = typeArray
|
|
return bstd.MarshalSlice(n+1, b, v, marshalAny)
|
|
default:
|
|
b[n] = typeString
|
|
return bstd.MarshalString(n+1, b, "unknown")
|
|
}
|
|
}
|
|
|
|
// unmarshalAny deserializes any value
|
|
func unmarshalAny(n int, b []byte) (int, any, error) {
|
|
if len(b) <= n {
|
|
return n, nil, benc.ErrBufTooSmall
|
|
}
|
|
|
|
switch b[n] {
|
|
case typeNull:
|
|
return n + 1, nil, nil
|
|
case typeString:
|
|
return bstd.UnmarshalString(n+1, b)
|
|
case typeInt:
|
|
n, v, err := bstd.UnmarshalInt64(n+1, b)
|
|
return n, v, err
|
|
case typeFloat:
|
|
return bstd.UnmarshalFloat64(n+1, b)
|
|
case typeBool:
|
|
return bstd.UnmarshalBool(n+1, b)
|
|
case typeBytes:
|
|
return bstd.UnmarshalBytesCopied(n+1, b)
|
|
case typeTable:
|
|
return bstd.UnmarshalMap[string, any](n+1, b, bstd.UnmarshalString, unmarshalAny)
|
|
case typeArray:
|
|
return bstd.UnmarshalSlice[any](n+1, b, unmarshalAny)
|
|
default:
|
|
return n + 1, nil, nil
|
|
}
|
|
}
|
|
|
|
// deepCopy returns a copy of v that shares no mutable storage with it.
//
// Tables (map[string]any) and arrays ([]any) are copied recursively, and
// byte slices are duplicated, so changing the result never changes v. A nil
// []byte stays nil. Every other value is returned as is, which is safe for
// the remaining supported types (string, integers, float64, bool, nil)
// because they are immutable.
func deepCopy(v any) any {
	switch v := v.(type) {
	case map[string]any:
		cp := make(map[string]any, len(v))
		for k, val := range v {
			cp[k] = deepCopy(val)
		}
		return cp
	case []any:
		cp := make([]any, len(v))
		for i, val := range v {
			cp[i] = deepCopy(val)
		}
		return cp
	case []byte:
		// Without this case the caller would receive the session's own
		// backing array and could modify stored data through it.
		if v == nil {
			return v
		}
		cp := make([]byte, len(v))
		copy(cp, v)
		return cp
	default:
		return v
	}
}
|
|
|
|
// validate checks that v consists only of types the encoder supports: nil,
// string, int, int64, float64, bool, []byte, and tables (map[string]any) or
// arrays ([]any) whose elements are themselves valid. It returns nil for a
// valid value, otherwise an error that names the offending type and the key
// or index path leading to it.
func validate(v any) error {
	switch val := v.(type) {
	case nil, string, int, int64, float64, bool, []byte:
		// Scalars and byte slices need no further inspection.
	case map[string]any:
		for key, item := range val {
			if err := validate(item); err != nil {
				return fmt.Errorf("invalid value for key %q: %w", key, err)
			}
		}
	case []any:
		for idx, item := range val {
			if err := validate(item); err != nil {
				return fmt.Errorf("invalid value at index %d: %w", idx, err)
			}
		}
	default:
		return fmt.Errorf("unsupported type: %T", val)
	}
	return nil
}
|
|
|
|
// deepEqual reports whether a and b hold equal session values.
//
// Strings, float64s and bools are equal when they have the same type and
// value. int and int64 are compared by numeric value, so int(1) equals
// int64(1). Byte slices are equal when their contents match, and tables and
// arrays are compared recursively. nil equals only nil. Values of any other
// type are compared with ==; if such a type is not comparable the result is
// false.
//
// The function never compares two non-nil interface values with == directly
// for the supported container types: doing so panics at run time when both
// hold the same uncomparable type (map[string]any, []any, []byte).
func deepEqual(a, b any) bool {
	// Comparing an interface against nil is always safe.
	if a == nil || b == nil {
		return a == nil && b == nil
	}

	switch va := a.(type) {
	case string:
		vb, ok := b.(string)
		return ok && va == vb
	case int:
		switch vb := b.(type) {
		case int:
			return va == vb
		case int64:
			return int64(va) == vb
		}
		return false
	case int64:
		switch vb := b.(type) {
		case int64:
			return va == vb
		case int:
			return va == int64(vb)
		}
		return false
	case float64:
		vb, ok := b.(float64)
		return ok && va == vb
	case bool:
		vb, ok := b.(bool)
		return ok && va == vb
	case []byte:
		vb, ok := b.([]byte)
		// The conversions in a string comparison do not allocate.
		return ok && string(va) == string(vb)
	case map[string]any:
		vb, ok := b.(map[string]any)
		if !ok || len(va) != len(vb) {
			return false
		}
		for k, av := range va {
			bv, exists := vb[k]
			if !exists || !deepEqual(av, bv) {
				return false
			}
		}
		return true
	case []any:
		vb, ok := b.([]any)
		if !ok || len(va) != len(vb) {
			return false
		}
		for i, av := range va {
			if !deepEqual(av, vb[i]) {
				return false
			}
		}
		return true
	default:
		return comparableEqual(a, b)
	}
}

// comparableEqual compares two values of an unsupported type with ==,
// reporting false instead of panicking when their dynamic type turns out not
// to be comparable.
func comparableEqual(a, b any) (equal bool) {
	defer func() {
		if recover() != nil {
			equal = false
		}
	}()
	return a == b
}
|
|
|
|
// IsEmpty returns true if the session has no data
|
|
func (s *Session) IsEmpty() bool {
|
|
return len(s.Data) == 0
|
|
}
|