From 50f4cb91f6a32c0a38195679ef1d09a0b46ac8c4 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Thu, 10 Apr 2025 14:27:50 -0500 Subject: [PATCH] finalize session management --- core/sessions/Session.go | 179 ++++++++++++++++++++++++--------------- 1 file changed, 109 insertions(+), 70 deletions(-) diff --git a/core/sessions/Session.go b/core/sessions/Session.go index 992c1fd..83b92e8 100644 --- a/core/sessions/Session.go +++ b/core/sessions/Session.go @@ -139,33 +139,26 @@ func (s *Session) SizePlain() (size int) { // ID size += bstd.SizeString(s.ID) - // Data (map of string to any) - // For simplicity, we store data as binary-encoded strings - // This is a simplification, in a real-world scenario you would handle - // different types differently - dataAsStrings := make(map[string]string) - for k, v := range s.Data { - dataAsStrings[k] = toString(v) - } - size += bstd.SizeMap(dataAsStrings, bstd.SizeString, bstd.SizeString) + // Data map + size += bstd.SizeMap(s.Data, bstd.SizeString, func(v any) int { + return sizeAny(v) + }) - // Time fields - size += bstd.SizeInt64() * 4 // Store Unix timestamps for all time fields + // Time fields stored as int64 Unix timestamps + size += bstd.SizeInt64() * 4 return size } // MarshalPlain serializes the session to binary -func (s *Session) MarshalPlain(n int, b []byte) (int, error) { +func (s *Session) MarshalPlain(n int, b []byte) int { // ID n = bstd.MarshalString(n, b, s.ID) - // Data - dataAsStrings := make(map[string]string) - for k, v := range s.Data { - dataAsStrings[k] = toString(v) - } - n = bstd.MarshalMap(n, b, dataAsStrings, bstd.MarshalString, bstd.MarshalString) + // Data map + n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, func(n int, b []byte, v any) int { + return marshalAny(n, b, v) + }) // Time fields as Unix timestamps n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix()) @@ -173,7 +166,7 @@ func (s *Session) MarshalPlain(n int, b []byte) (int, error) { n = bstd.MarshalInt64(n, b, s.LastUsed.Unix()) n = bstd.MarshalInt64(n, b, s.Expiry.Unix()) - return n, nil + return n } // UnmarshalPlain deserializes the session from binary @@ -186,44 +179,35 @@ func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) { return n, err } - // Data - var dataAsStrings map[string]string - n, dataAsStrings, err = bstd.UnmarshalMap[string, string](n, b, bstd.UnmarshalString, bstd.UnmarshalString) + // Data map + n, s.Data, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, func(n int, b []byte) (int, any, error) { + return unmarshalAny(n, b) + }) if err != nil { return n, err } - // Convert string data back to original types - s.Data = make(map[string]any, len(dataAsStrings)) - for k, v := range dataAsStrings { - s.Data[k] = fromString(v) - } - - // Time fields + // Time fields as Unix timestamps var timestamp int64 - // CreatedAt n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.CreatedAt = time.Unix(timestamp, 0) - // UpdatedAt n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.UpdatedAt = time.Unix(timestamp, 0) - // LastUsed n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err } s.LastUsed = time.Unix(timestamp, 0) - // Expiry n, timestamp, err = bstd.UnmarshalInt64(n, b) if err != nil { return n, err @@ -238,8 +222,7 @@ func (s *Session) Marshal() ([]byte, error) { size := s.SizePlain() data, err := bufPool.Marshal(size, func(b []byte) (n int) { - n, _ = s.MarshalPlain(0, b) - return n + return s.MarshalPlain(0, b) }) if err != nil { @@ -260,52 +243,108 @@ func Unmarshal(data []byte) (*Session, error) { return session, nil } -// Helper functions to convert between any and string -// In a production environment, you would use a more robust serialization method for the map values -func toString(v any) string { +// Type identifiers for any values +const ( + typeNull byte = 0 + typeString byte = 1 + typeInt byte = 2 + typeFloat byte = 3 + typeBool byte = 4 + typeBytes byte = 5 +) + +// sizeAny calculates the size needed for any value +func sizeAny(v any) int { if v == nil { - return "" + return 1 // Just the type byte } - switch t := v.(type) { + + // 1 byte for type + size of the value + switch val := v.(type) { case string: - return t - case []byte: - return string(t) + return 1 + bstd.SizeString(val) case int: - return "i:" + string(rune(t)) + return 1 + bstd.SizeInt64() + case int64: + return 1 + bstd.SizeInt64() + case float64: + return 1 + bstd.SizeFloat64() case bool: - if t { - return "b:t" - } - return "b:f" + return 1 + bstd.SizeBool() + case []byte: + return 1 + bstd.SizeBytes(val) default: - return "u:" // unknown type + // Convert unhandled types to string + return 1 + bstd.SizeString("unknown") } } -func fromString(s string) any { - if s == "" { - return nil - } - if len(s) < 2 { - return s +// marshalAny serializes any value +func marshalAny(n int, b []byte, v any) int { + if v == nil { + b[n] = typeNull + return n + 1 } - prefix := s[:2] - switch prefix { - case "i:": - if len(s) > 2 { - return int(rune(s[2])) - } - return 0 - case "b:": - if len(s) > 2 && s[2] == 't' { - return true - } - return false - case "u:": - return nil + switch val := v.(type) { + case string: + b[n] = typeString + return bstd.MarshalString(n+1, b, val) + case int: + b[n] = typeInt + return bstd.MarshalInt64(n+1, b, int64(val)) + case int64: + b[n] = typeInt + return bstd.MarshalInt64(n+1, b, val) + case float64: + b[n] = typeFloat + return bstd.MarshalFloat64(n+1, b, val) + case bool: + b[n] = typeBool + return bstd.MarshalBool(n+1, b, val) + case []byte: + b[n] = typeBytes + return bstd.MarshalBytes(n+1, b, val) default: - return s + // Convert unhandled types to string + b[n] = typeString + return bstd.MarshalString(n+1, b, "unknown") + } +} + +// unmarshalAny deserializes any value +func unmarshalAny(n int, b []byte) (int, any, error) { + if len(b) <= n { + return n, nil, benc.ErrBufTooSmall + } + + typeId := b[n] + n++ + + switch typeId { + case typeNull: + return n, nil, nil + case typeString: + return bstd.UnmarshalString(n, b) + case typeInt: + var val int64 + var err error + n, val, err = bstd.UnmarshalInt64(n, b) + return n, val, err + case typeFloat: + var val float64 + var err error + n, val, err = bstd.UnmarshalFloat64(n, b) + return n, val, err + case typeBool: + var val bool + var err error + n, val, err = bstd.UnmarshalBool(n, b) + return n, val, err + case typeBytes: + return bstd.UnmarshalBytesCopied(n, b) + default: + // Unknown type, return nil + return n, nil, nil } }