add table support to sessions, fix root lua path, optimize sesison manager

This commit is contained in:
Sky Johnson 2025-05-26 12:56:20 -05:00
parent 82470b35a0
commit e4cd490f0f
4 changed files with 228 additions and 274 deletions

View File

@ -164,7 +164,8 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
session := s.sessionManager.GetSessionFromRequest(ctx) session := s.sessionManager.GetSessionFromRequest(ctx)
sessionMap["id"] = session.ID sessionMap["id"] = session.ID
sessionMap["data"] = session.Data
sessionMap["data"] = session.GetAll() // This now returns a deep copy
luaCtx.Set("method", method) luaCtx.Set("method", method)
luaCtx.Set("path", path) luaCtx.Set("path", path)
@ -209,11 +210,12 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
delete(response.SessionData, "__clear_all") delete(response.SessionData, "__clear_all")
} }
// Apply session changes - now supports nested tables
for k, v := range response.SessionData { for k, v := range response.SessionData {
if v == "__SESSION_DELETE_MARKER__" { if v == "__SESSION_DELETE_MARKER__" {
session.Delete(k) session.Delete(k)
} else { } else {
session.Set(k, v) session.Set(k, v) // This will handle tables through marshalling
} }
} }

View File

@ -485,6 +485,15 @@ func (r *LuaRouter) Match(method, path string, params *Params) (*node, bool) {
// matchPath recursively matches a path against the routing tree // matchPath recursively matches a path against the routing tree
func (r *LuaRouter) matchPath(current *node, segments []string, params *Params, depth int) (*node, bool) { func (r *LuaRouter) matchPath(current *node, segments []string, params *Params, depth int) (*node, bool) {
// Filter empty segments
filteredSegments := segments[:0]
for _, segment := range segments {
if segment != "" {
filteredSegments = append(filteredSegments, segment)
}
}
segments = filteredSegments
if len(segments) == 0 { if len(segments) == 0 {
if current.handler != "" { if current.handler != "" {
return current, true return current, true

View File

@ -20,40 +20,25 @@ const (
// SessionManager handles multiple sessions // SessionManager handles multiple sessions
type SessionManager struct { type SessionManager struct {
cache *fastcache.Cache cache *fastcache.Cache
maxSessions int
cookieName string cookieName string
cookiePath string cookiePath string
cookieDomain string cookieDomain string
cookieSecure bool cookieSecure bool
cookieHTTPOnly bool cookieHTTPOnly bool
cookieMaxAge int cookieMaxAge int
cookieMu sync.RWMutex // Only cookie options need a mutex cookieMu sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
cleanupDone chan struct{} cleanupDone chan struct{}
} }
// InitializeSessionPool pre-allocates session objects
func InitializeSessionPool(size int) {
for range size {
session := &Session{
Data: make(map[string]any, 8),
}
ReturnToPool(session)
}
}
// NewSessionManager creates a new session manager // NewSessionManager creates a new session manager
func NewSessionManager(maxSessions int) *SessionManager { func NewSessionManager(maxSessions int) *SessionManager {
if maxSessions <= 0 { if maxSessions <= 0 {
maxSessions = DefaultMaxSessions maxSessions = DefaultMaxSessions
} }
// Estimate max memory: ~4KB per session × maxSessions
maxBytes := maxSessions * 4096
sm := &SessionManager{ sm := &SessionManager{
cache: fastcache.New(maxBytes), cache: fastcache.New(maxSessions * 4096),
maxSessions: maxSessions,
cookieName: DefaultCookieName, cookieName: DefaultCookieName,
cookiePath: DefaultCookiePath, cookiePath: DefaultCookiePath,
cookieHTTPOnly: true, cookieHTTPOnly: true,
@ -61,10 +46,12 @@ func NewSessionManager(maxSessions int) *SessionManager {
cleanupDone: make(chan struct{}), cleanupDone: make(chan struct{}),
} }
// Pre-allocate session objects for common pool size // Pre-populate session pool
InitializeSessionPool(100) // Adjust based on expected concurrent requests for i := 0; i < 100; i++ {
s := NewSession("", 0)
s.Release()
}
// Start periodic cleanup
sm.cleanupTicker = time.NewTicker(CleanupInterval) sm.cleanupTicker = time.NewTicker(CleanupInterval)
go sm.cleanupRoutine() go sm.cleanupRoutine()
@ -76,7 +63,6 @@ func (sm *SessionManager) Stop() {
close(sm.cleanupDone) close(sm.cleanupDone)
} }
// cleanupRoutine periodically removes expired sessions
func (sm *SessionManager) cleanupRoutine() { func (sm *SessionManager) cleanupRoutine() {
for { for {
select { select {
@ -89,124 +75,80 @@ func (sm *SessionManager) cleanupRoutine() {
} }
} }
// generateSessionID creates a random session ID
func generateSessionID() string {
id, _ := gonanoid.New()
return id
}
// GetSession retrieves a session by ID, or creates a new one if it doesn't exist // GetSession retrieves a session by ID, or creates a new one if it doesn't exist
func (sm *SessionManager) GetSession(id string) *Session { func (sm *SessionManager) GetSession(id string) *Session {
// Try to get an existing session
if id != "" { if id != "" {
data := sm.cache.Get(nil, []byte(id)) if data := sm.cache.Get(nil, []byte(id)); len(data) > 0 {
if len(data) > 0 { if s, err := Unmarshal(data); err == nil && !s.IsExpired() {
session, err := Unmarshal(data) s.UpdateLastUsed()
if err == nil && !session.IsExpired() { s.ResetDirty()
session.UpdateLastUsed() return s
session.ResetDirty() // Start clean
return session
} }
// Session expired or corrupt, remove it
sm.cache.Del([]byte(id)) sm.cache.Del([]byte(id))
} }
} }
// Create a new session
return sm.CreateSession() return sm.CreateSession()
} }
// CreateSession generates a new session with a unique ID // CreateSession generates a new session with a unique ID
func (sm *SessionManager) CreateSession() *Session { func (sm *SessionManager) CreateSession() *Session {
id := generateSessionID() id, _ := gonanoid.New()
// Ensure ID uniqueness // Ensure uniqueness (max 3 attempts)
attempts := 0 for i := 0; i < 3 && sm.cache.Has([]byte(id)); i++ {
for attempts < 3 { id, _ = gonanoid.New()
if sm.cache.Has([]byte(id)) {
id = generateSessionID()
attempts++
} else {
break
}
} }
session := NewSession(id, sm.cookieMaxAge) s := NewSession(id, sm.cookieMaxAge)
if data, err := s.Marshal(); err == nil {
// Serialize and store the session
if data, err := session.Marshal(); err == nil {
sm.cache.Set([]byte(id), data) sm.cache.Set([]byte(id), data)
} }
s.ResetDirty()
session.ResetDirty() // Start clean return s
return session
} }
// DestroySession removes a session // DestroySession removes a session
func (sm *SessionManager) DestroySession(id string) { func (sm *SessionManager) DestroySession(id string) {
// Get and clean session from cache before deleting if data := sm.cache.Get(nil, []byte(id)); len(data) > 0 {
data := sm.cache.Get(nil, []byte(id)) if s, err := Unmarshal(data); err == nil {
if len(data) > 0 { s.Release()
if session, err := Unmarshal(data); err == nil {
ReturnToPool(session)
} }
} }
sm.cache.Del([]byte(id)) sm.cache.Del([]byte(id))
} }
// CleanupExpired removes all expired sessions // CleanupExpired removes all expired sessions
// Note: fastcache doesn't provide iteration, so we can't clean all expired sessions
// This is a limitation of this implementation
func (sm *SessionManager) CleanupExpired() int { func (sm *SessionManager) CleanupExpired() int {
// No way to iterate through all keys in fastcache // fastcache doesn't support iteration
// We'd need to track expiring sessions separately
return 0 return 0
} }
// SetCookieOptions configures cookie parameters // SetCookieOptions configures cookie parameters
func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) { func (sm *SessionManager) SetCookieOptions(name, path, domain string, secure, httpOnly bool, maxAge int) {
sm.cookieMu.Lock() sm.cookieMu.Lock()
defer sm.cookieMu.Unlock()
sm.cookieName = name sm.cookieName = name
sm.cookiePath = path sm.cookiePath = path
sm.cookieDomain = domain sm.cookieDomain = domain
sm.cookieSecure = secure sm.cookieSecure = secure
sm.cookieHTTPOnly = httpOnly sm.cookieHTTPOnly = httpOnly
sm.cookieMaxAge = maxAge sm.cookieMaxAge = maxAge
sm.cookieMu.Unlock()
} }
// GetSessionFromRequest extracts the session from a request // GetSessionFromRequest extracts the session from a request
func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session { func (sm *SessionManager) GetSessionFromRequest(ctx *fasthttp.RequestCtx) *Session {
sm.cookieMu.RLock() sm.cookieMu.RLock()
cookieName := sm.cookieName name := sm.cookieName
sm.cookieMu.RUnlock() sm.cookieMu.RUnlock()
cookie := ctx.Request.Header.Cookie(cookieName) if cookie := ctx.Request.Header.Cookie(name); len(cookie) > 0 {
if len(cookie) == 0 {
return sm.CreateSession()
}
return sm.GetSession(string(cookie)) return sm.GetSession(string(cookie))
}
return sm.CreateSession()
} }
// ApplySessionCookie adds the session cookie to the response // ApplySessionCookie adds the session cookie to the response
func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) { func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *Session) {
cookie := fasthttp.AcquireCookie()
defer fasthttp.ReleaseCookie(cookie)
// Get cookie options with minimal lock time
sm.cookieMu.RLock()
cookieName := sm.cookieName
cookiePath := sm.cookiePath
cookieDomain := sm.cookieDomain
cookieSecure := sm.cookieSecure
cookieHTTPOnly := sm.cookieHTTPOnly
cookieMaxAge := sm.cookieMaxAge
sm.cookieMu.RUnlock()
// Store updated session only if it has changes
if session.IsDirty() { if session.IsDirty() {
if data, err := session.Marshal(); err == nil { if data, err := session.Marshal(); err == nil {
sm.cache.Set([]byte(session.ID), data) sm.cache.Set([]byte(session.ID), data)
@ -214,18 +156,21 @@ func (sm *SessionManager) ApplySessionCookie(ctx *fasthttp.RequestCtx, session *
session.ResetDirty() session.ResetDirty()
} }
cookie.SetKey(cookieName) cookie := fasthttp.AcquireCookie()
cookie.SetValue(session.ID) defer fasthttp.ReleaseCookie(cookie)
cookie.SetPath(cookiePath)
cookie.SetHTTPOnly(cookieHTTPOnly)
cookie.SetMaxAge(cookieMaxAge)
if cookieDomain != "" { sm.cookieMu.RLock()
cookie.SetDomain(cookieDomain) cookie.SetKey(sm.cookieName)
cookie.SetPath(sm.cookiePath)
cookie.SetHTTPOnly(sm.cookieHTTPOnly)
cookie.SetMaxAge(sm.cookieMaxAge)
if sm.cookieDomain != "" {
cookie.SetDomain(sm.cookieDomain)
} }
cookie.SetSecure(sm.cookieSecure)
sm.cookieMu.RUnlock()
cookie.SetSecure(cookieSecure) cookie.SetValue(session.ID)
ctx.Response.Header.SetCookie(cookie) ctx.Response.Header.SetCookie(cookie)
} }
@ -244,9 +189,6 @@ func (sm *SessionManager) CookieOptions() map[string]any {
} }
} }
// GlobalSessionManager is the default session manager instance
var GlobalSessionManager = NewSessionManager(DefaultMaxSessions)
// GetCacheStats returns statistics about the session cache // GetCacheStats returns statistics about the session cache
func (sm *SessionManager) GetCacheStats() map[string]uint64 { func (sm *SessionManager) GetCacheStats() map[string]uint64 {
if sm == nil || sm.cache == nil { if sm == nil || sm.cache == nil {
@ -265,3 +207,6 @@ func (sm *SessionManager) GetCacheStats() map[string]uint64 {
"misses": stats.Misses, "misses": stats.Misses,
} }
} }
// GlobalSessionManager is the default session manager instance
var GlobalSessionManager = NewSessionManager(DefaultMaxSessions)

View File

@ -1,6 +1,7 @@
package sessions package sessions
import ( import (
"fmt"
"sync" "sync"
"time" "time"
@ -16,67 +17,66 @@ type Session struct {
UpdatedAt time.Time UpdatedAt time.Time
LastUsed time.Time LastUsed time.Time
Expiry time.Time Expiry time.Time
dirty bool // Tracks if session has changes, not serialized dirty bool
} }
// Session pool to reduce allocations var (
var sessionPool = sync.Pool{ sessionPool = sync.Pool{
New: func() any { New: func() any {
return &Session{ return &Session{Data: make(map[string]any, 8)}
Data: make(map[string]any, 8),
}
}, },
}
// BufPool for reusing serialization buffers
var bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
// GetFromPool retrieves a session from the pool
func GetFromPool() *Session {
return sessionPool.Get().(*Session)
}
// ReturnToPool returns a session to the pool after cleaning it
func ReturnToPool(s *Session) {
if s == nil {
return
} }
bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
// Clean the session for reuse )
s.ID = ""
for k := range s.Data {
delete(s.Data, k)
}
s.CreatedAt = time.Time{}
s.UpdatedAt = time.Time{}
s.LastUsed = time.Time{}
s.Expiry = time.Time{}
s.dirty = false
sessionPool.Put(s)
}
// NewSession creates a new session with the given ID // NewSession creates a new session with the given ID
func NewSession(id string, maxAge int) *Session { func NewSession(id string, maxAge int) *Session {
s := sessionPool.Get().(*Session)
now := time.Now() now := time.Now()
*s = Session{
// Get from pool or create new ID: id,
session := GetFromPool() Data: s.Data, // Reuse map
CreatedAt: now,
// Initialize UpdatedAt: now,
session.ID = id LastUsed: now,
session.CreatedAt = now Expiry: now.Add(time.Duration(maxAge) * time.Second),
session.UpdatedAt = now }
session.LastUsed = now return s
session.Expiry = now.Add(time.Duration(maxAge) * time.Second)
session.dirty = false
return session
} }
// Get retrieves a value from the session // Release returns the session to the pool
func (s *Session) Release() {
for k := range s.Data {
delete(s.Data, k)
}
sessionPool.Put(s)
}
// Get returns a deep copy of a value
func (s *Session) Get(key string) any { func (s *Session) Get(key string) any {
return s.Data[key] if v, ok := s.Data[key]; ok {
return deepCopy(v)
}
return nil
}
// GetTable returns a value as a table
func (s *Session) GetTable(key string) map[string]any {
if v := s.Get(key); v != nil {
if t, ok := v.(map[string]any); ok {
return t
}
}
return nil
}
// GetAll returns a deep copy of all session data
func (s *Session) GetAll() map[string]any {
copy := make(map[string]any, len(s.Data))
for k, v := range s.Data {
copy[k] = deepCopy(v)
}
return copy
} }
// Set stores a value in the session // Set stores a value in the session
@ -86,6 +86,20 @@ func (s *Session) Set(key string, value any) {
s.dirty = true s.dirty = true
} }
// SetSafe stores a value with validation
func (s *Session) SetSafe(key string, value any) error {
if err := validate(value); err != nil {
return fmt.Errorf("session.SetSafe: %w", err)
}
s.Set(key, value)
return nil
}
// SetTable is a convenience method for setting table data
func (s *Session) SetTable(key string, table map[string]any) error {
return s.SetSafe(key, table)
}
// Delete removes a value from the session // Delete removes a value from the session
func (s *Session) Delete(key string) { func (s *Session) Delete(key string) {
delete(s.Data, key) delete(s.Data, key)
@ -100,27 +114,16 @@ func (s *Session) Clear() {
s.dirty = true s.dirty = true
} }
// GetAll returns a copy of all session data
func (s *Session) GetAll() map[string]any {
copy := make(map[string]any, len(s.Data))
for k, v := range s.Data {
copy[k] = v
}
return copy
}
// IsExpired checks if the session has expired // IsExpired checks if the session has expired
func (s *Session) IsExpired() bool { func (s *Session) IsExpired() bool {
return time.Now().After(s.Expiry) return time.Now().After(s.Expiry)
} }
// UpdateLastUsed updates the last used time // UpdateLastUsed updates the last used time
// Only updates if at least 5 seconds have passed since last update
func (s *Session) UpdateLastUsed() { func (s *Session) UpdateLastUsed() {
now := time.Now() now := time.Now()
if now.Sub(s.LastUsed) > 5*time.Second { if now.Sub(s.LastUsed) > 5*time.Second {
s.LastUsed = now s.LastUsed = now
// Not marking dirty for LastUsed updates to reduce writes
} }
} }
@ -135,115 +138,64 @@ func (s *Session) ResetDirty() {
} }
// SizePlain calculates the size needed to marshal the session // SizePlain calculates the size needed to marshal the session
func (s *Session) SizePlain() (size int) { func (s *Session) SizePlain() int {
// ID return bstd.SizeString(s.ID) +
size += bstd.SizeString(s.ID) bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) +
bstd.SizeInt64()*4
// Data map
size += bstd.SizeMap(s.Data, bstd.SizeString, func(v any) int {
return sizeAny(v)
})
// Time fields stored as int64 Unix timestamps
size += bstd.SizeInt64() * 4
return size
} }
// MarshalPlain serializes the session to binary // MarshalPlain serializes the session to binary
func (s *Session) MarshalPlain(n int, b []byte) int { func (s *Session) MarshalPlain(n int, b []byte) int {
// ID
n = bstd.MarshalString(n, b, s.ID) n = bstd.MarshalString(n, b, s.ID)
n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny)
// Data map
n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, func(n int, b []byte, v any) int {
return marshalAny(n, b, v)
})
// Time fields as Unix timestamps
n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix())
n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix())
n = bstd.MarshalInt64(n, b, s.LastUsed.Unix()) n = bstd.MarshalInt64(n, b, s.LastUsed.Unix())
n = bstd.MarshalInt64(n, b, s.Expiry.Unix()) return bstd.MarshalInt64(n, b, s.Expiry.Unix())
return n
} }
// UnmarshalPlain deserializes the session from binary // UnmarshalPlain deserializes the session from binary
func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) { func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) {
var err error var err error
// ID
n, s.ID, err = bstd.UnmarshalString(n, b) n, s.ID, err = bstd.UnmarshalString(n, b)
if err != nil { if err != nil {
return n, err return n, err
} }
// Data map n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny)
n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, func(n int, b []byte) (int, any, error) {
return unmarshalAny(n, b)
})
if err != nil { if err != nil {
return n, err return n, err
} }
// Time fields as Unix timestamps var ts int64
var timestamp int64 for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} {
n, ts, err = bstd.UnmarshalInt64(n, b)
n, timestamp, err = bstd.UnmarshalInt64(n, b)
if err != nil { if err != nil {
return n, err return n, err
} }
s.CreatedAt = time.Unix(timestamp, 0) *t = time.Unix(ts, 0)
n, timestamp, err = bstd.UnmarshalInt64(n, b)
if err != nil {
return n, err
} }
s.UpdatedAt = time.Unix(timestamp, 0)
n, timestamp, err = bstd.UnmarshalInt64(n, b)
if err != nil {
return n, err
}
s.LastUsed = time.Unix(timestamp, 0)
n, timestamp, err = bstd.UnmarshalInt64(n, b)
if err != nil {
return n, err
}
s.Expiry = time.Unix(timestamp, 0)
return n, nil return n, nil
} }
// Marshal serializes the session using benc // Marshal serializes the session using benc
func (s *Session) Marshal() ([]byte, error) { func (s *Session) Marshal() ([]byte, error) {
size := s.SizePlain() return bufPool.Marshal(s.SizePlain(), func(b []byte) int {
data, err := bufPool.Marshal(size, func(b []byte) (n int) {
return s.MarshalPlain(0, b) return s.MarshalPlain(0, b)
}) })
if err != nil {
return nil, err
}
return data, nil
} }
// Unmarshal deserializes a session using benc // Unmarshal deserializes a session using benc
func Unmarshal(data []byte) (*Session, error) { func Unmarshal(data []byte) (*Session, error) {
session := GetFromPool() s := sessionPool.Get().(*Session)
_, err := session.UnmarshalPlain(0, data) if _, err := s.UnmarshalPlain(0, data); err != nil {
if err != nil { s.Release()
ReturnToPool(session)
return nil, err return nil, err
} }
return session, nil return s, nil
} }
// Type identifiers for any values // Type identifiers
const ( const (
typeNull byte = 0 typeNull byte = 0
typeString byte = 1 typeString byte = 1
@ -251,32 +203,38 @@ const (
typeFloat byte = 3 typeFloat byte = 3
typeBool byte = 4 typeBool byte = 4
typeBytes byte = 5 typeBytes byte = 5
typeTable byte = 6
typeArray byte = 7
) )
// sizeAny calculates the size needed for any value // sizeAny calculates the size needed for any value
func sizeAny(v any) int { func sizeAny(v any) int {
if v == nil { if v == nil {
return 1 // Just the type byte return 1
} }
// 1 byte for type + size of the value size := 1 // type byte
switch val := v.(type) { switch v := v.(type) {
case string: case string:
return 1 + bstd.SizeString(val) size += bstd.SizeString(v)
case int: case int:
return 1 + bstd.SizeInt64() size += bstd.SizeInt64()
case int64: case int64:
return 1 + bstd.SizeInt64() size += bstd.SizeInt64()
case float64: case float64:
return 1 + bstd.SizeFloat64() size += bstd.SizeFloat64()
case bool: case bool:
return 1 + bstd.SizeBool() size += bstd.SizeBool()
case []byte: case []byte:
return 1 + bstd.SizeBytes(val) size += bstd.SizeBytes(v)
case map[string]any:
size += bstd.SizeMap(v, bstd.SizeString, sizeAny)
case []any:
size += bstd.SizeSlice(v, sizeAny)
default: default:
// Convert unhandled types to string size += bstd.SizeString("unknown")
return 1 + bstd.SizeString("unknown")
} }
return size
} }
// marshalAny serializes any value // marshalAny serializes any value
@ -286,27 +244,32 @@ func marshalAny(n int, b []byte, v any) int {
return n + 1 return n + 1
} }
switch val := v.(type) { switch v := v.(type) {
case string: case string:
b[n] = typeString b[n] = typeString
return bstd.MarshalString(n+1, b, val) return bstd.MarshalString(n+1, b, v)
case int: case int:
b[n] = typeInt b[n] = typeInt
return bstd.MarshalInt64(n+1, b, int64(val)) return bstd.MarshalInt64(n+1, b, int64(v))
case int64: case int64:
b[n] = typeInt b[n] = typeInt
return bstd.MarshalInt64(n+1, b, val) return bstd.MarshalInt64(n+1, b, v)
case float64: case float64:
b[n] = typeFloat b[n] = typeFloat
return bstd.MarshalFloat64(n+1, b, val) return bstd.MarshalFloat64(n+1, b, v)
case bool: case bool:
b[n] = typeBool b[n] = typeBool
return bstd.MarshalBool(n+1, b, val) return bstd.MarshalBool(n+1, b, v)
case []byte: case []byte:
b[n] = typeBytes b[n] = typeBytes
return bstd.MarshalBytes(n+1, b, val) return bstd.MarshalBytes(n+1, b, v)
case map[string]any:
b[n] = typeTable
return bstd.MarshalMap(n+1, b, v, bstd.MarshalString, marshalAny)
case []any:
b[n] = typeArray
return bstd.MarshalSlice(n+1, b, v, marshalAny)
default: default:
// Convert unhandled types to string
b[n] = typeString b[n] = typeString
return bstd.MarshalString(n+1, b, "unknown") return bstd.MarshalString(n+1, b, "unknown")
} }
@ -318,33 +281,68 @@ func unmarshalAny(n int, b []byte) (int, any, error) {
return n, nil, benc.ErrBufTooSmall return n, nil, benc.ErrBufTooSmall
} }
typeId := b[n] switch b[n] {
n++
switch typeId {
case typeNull: case typeNull:
return n, nil, nil return n + 1, nil, nil
case typeString: case typeString:
return bstd.UnmarshalString(n, b) return bstd.UnmarshalString(n+1, b)
case typeInt: case typeInt:
var val int64 n, v, err := bstd.UnmarshalInt64(n+1, b)
var err error return n, v, err
n, val, err = bstd.UnmarshalInt64(n, b)
return n, val, err
case typeFloat: case typeFloat:
var val float64 return bstd.UnmarshalFloat64(n+1, b)
var err error
n, val, err = bstd.UnmarshalFloat64(n, b)
return n, val, err
case typeBool: case typeBool:
var val bool return bstd.UnmarshalBool(n+1, b)
var err error
n, val, err = bstd.UnmarshalBool(n, b)
return n, val, err
case typeBytes: case typeBytes:
return bstd.UnmarshalBytesCopied(n, b) return bstd.UnmarshalBytesCopied(n+1, b)
case typeTable:
return bstd.UnmarshalMap[string, any](n+1, b, bstd.UnmarshalString, unmarshalAny)
case typeArray:
return bstd.UnmarshalSlice[any](n+1, b, unmarshalAny)
default: default:
// Unknown type, return nil return n + 1, nil, nil
return n, nil, nil
} }
} }
// deepCopy creates a deep copy of any value
func deepCopy(v any) any {
switch v := v.(type) {
case map[string]any:
cp := make(map[string]any, len(v))
for k, val := range v {
cp[k] = deepCopy(val)
}
return cp
case []any:
cp := make([]any, len(v))
for i, val := range v {
cp[i] = deepCopy(val)
}
return cp
default:
return v
}
}
// validate ensures a value can be safely serialized
func validate(v any) error {
switch v := v.(type) {
case nil, string, int, int64, float64, bool, []byte:
return nil
case map[string]any:
for k, val := range v {
if err := validate(val); err != nil {
return fmt.Errorf("invalid value for key %q: %w", k, err)
}
}
case []any:
for i, val := range v {
if err := validate(val); err != nil {
return fmt.Errorf("invalid value at index %d: %w", i, err)
}
}
default:
return fmt.Errorf("unsupported type: %T", v)
}
return nil
}