add flash session support, fix template rendering control flow

This commit is contained in:
Sky Johnson 2025-06-04 09:54:13 -05:00
parent 2c0067dfcf
commit 769a8dd2ce
4 changed files with 335 additions and 32 deletions

View File

@ -188,13 +188,19 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
}() }()
session := s.sessionManager.GetSessionFromRequest(ctx) session := s.sessionManager.GetSessionFromRequest(ctx)
// Advance flash data (move current flash to old, clear old)
session.AdvanceFlash()
sessionMap["id"] = session.ID sessionMap["id"] = session.ID
// Only get session data if not empty // Get session data and flash data
if !session.IsEmpty() { if !session.IsEmpty() {
sessionMap["data"] = session.GetAll() sessionMap["data"] = session.GetAll() // This now includes flash data
sessionMap["flash"] = session.GetAllFlash()
} else { } else {
sessionMap["data"] = emptyMap sessionMap["data"] = emptyMap
sessionMap["flash"] = emptyMap
} }
// Set basic context // Set basic context
@ -250,10 +256,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
return return
} }
// Handle session updates // Handle session updates including flash data
if len(response.SessionData) > 0 { if len(response.SessionData) > 0 {
if _, clearAll := response.SessionData["__clear_all"]; clearAll { if _, clearAll := response.SessionData["__clear_all"]; clearAll {
session.Clear() session.Clear()
session.ClearFlash() // Also clear flash data
delete(response.SessionData, "__clear_all") delete(response.SessionData, "__clear_all")
} }
@ -266,6 +273,16 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
} }
} }
// Handle flash data from response
if flashData, ok := response.Metadata["flash"].(map[string]any); ok {
for k, v := range flashData {
if err := session.FlashSafe(k, v); err != nil && s.debugMode {
logger.Warnf("Error setting flash data %s: %v", k, err)
}
}
delete(response.Metadata, "flash") // Remove from metadata after processing
}
s.sessionManager.ApplySessionCookie(ctx, session) s.sessionManager.ApplySessionCookie(ctx, session)
runner.ApplyResponse(response, ctx) runner.ApplyResponse(response, ctx)
runner.ReleaseResponse(response) runner.ReleaseResponse(response)

View File

@ -233,17 +233,23 @@ end
-- Template processing with code execution -- Template processing with code execution
function render(template_str, env) function render(template_str, env)
local function get_line(s, ln) local function is_control_structure(code)
for line in s:gmatch("([^\n]*)\n?") do -- Check if code is a control structure that doesn't produce output
if ln == 1 then return line end local trimmed = code:match("^%s*(.-)%s*$")
ln = ln - 1 return trimmed == "else" or
end trimmed == "end" or
end trimmed:match("^if%s") or
trimmed:match("^elseif%s") or
local function pos_to_line(s, pos) trimmed:match("^for%s") or
local line = 1 trimmed:match("^while%s") or
for _ in s:sub(1, pos):gmatch("\n") do line = line + 1 end trimmed:match("^repeat%s*$") or
return line trimmed:match("^until%s") or
trimmed:match("^do%s*$") or
trimmed:match("^local%s") or
trimmed:match("^function%s") or
trimmed:match(".*=%s*function%s*%(") or
trimmed:match(".*then%s*$") or
trimmed:match(".*do%s*$")
end end
local pos, chunks = 1, {} local pos, chunks = 1, {}
@ -273,11 +279,9 @@ function render(template_str, env)
end end
local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$") local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$")
local is_control = is_control_structure(code)
-- Check if it's a simple variable name for escaped output table.insert(chunks, {tag_type, code, pos, is_control})
local is_simple_var = tag_type == "=" and code:match("^[%w_]+$")
table.insert(chunks, {tag_type, code, pos, is_simple_var})
pos = close_stop + 1 pos = close_stop + 1
end end
@ -288,24 +292,43 @@ function render(template_str, env)
table.insert(buffer, "_b_i = _b_i + 1\n") table.insert(buffer, "_b_i = _b_i + 1\n")
table.insert(buffer, "_b[_b_i] = " .. string.format("%q", chunk) .. "\n") table.insert(buffer, "_b[_b_i] = " .. string.format("%q", chunk) .. "\n")
else else
t = chunk[1] local tag_type, code, pos, is_control = chunk[1], chunk[2], chunk[3], chunk[4]
if t == "=" then
if chunk[4] then -- is_simple_var if is_control then
-- Control structure - just insert as raw Lua code
table.insert(buffer, "--[[" .. pos .. "]] " .. code .. "\n")
elseif tag_type == "=" then
-- Simple variable check
if code:match("^[%w_]+$") then
table.insert(buffer, "_b_i = _b_i + 1\n") table.insert(buffer, "_b_i = _b_i + 1\n")
table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _escape(_tostring(" .. chunk[2] .. "))\n") table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n")
else else
table.insert(buffer, "--[[" .. chunk[3] .. "]] " .. chunk[2] .. "\n") -- Expression output with escaping
table.insert(buffer, "_b_i = _b_i + 1\n")
table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n")
end end
elseif t == "-" then elseif tag_type == "-" then
-- Unescaped output
table.insert(buffer, "_b_i = _b_i + 1\n") table.insert(buffer, "_b_i = _b_i + 1\n")
table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _tostring(" .. chunk[2] .. ")\n") table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _tostring(" .. code .. ")\n")
end end
end end
end end
table.insert(buffer, "return _b") table.insert(buffer, "return _b")
local fn, err = loadstring(table.concat(buffer)) local generated_code = table.concat(buffer)
if not fn then error(err) end
-- DEBUG: Uncomment to see generated code
-- print("Generated Lua code:")
-- print(generated_code)
-- print("---")
local fn, err = loadstring(generated_code)
if not fn then
print("Generated code that failed to compile:")
print(generated_code)
error(err)
end
env = env or {} env = env or {}
local runtime_env = setmetatable({}, {__index = function(_, k) return env[k] or _G[k] end}) local runtime_env = setmetatable({}, {__index = function(_, k) return env[k] or _G[k] end})
@ -428,3 +451,142 @@ function password_verify(plain_password, hash_string)
return __password_verify(plain_password, hash_string) return __password_verify(plain_password, hash_string)
end end
-- ======================================================================
-- SESSION FLASH FUNCTIONS
-- ======================================================================
function session_flash(key, value)
__response.flash = __response.flash or {}
__response.flash[key] = value
end
function session_get_flash(key)
-- Check current flash data first
if __response.flash and __response.flash[key] ~= nil then
return __response.flash[key]
end
-- Check session flash data
if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then
return __ctx.session.flash[key]
end
return nil
end
function session_has_flash(key)
-- Check current flash
if __response.flash and __response.flash[key] ~= nil then
return true
end
-- Check session flash
if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then
return true
end
return false
end
function session_get_all_flash()
local flash = {}
-- Add session flash data first
if __ctx.session and __ctx.session.flash then
for k, v in pairs(__ctx.session.flash) do
flash[k] = v
end
end
-- Add current response flash (overwrites session flash if same key)
if __response.flash then
for k, v in pairs(__response.flash) do
flash[k] = v
end
end
return flash
end
function session_flash_now(key, value)
-- Flash for current request only (not persisted)
_G._current_flash = _G._current_flash or {}
_G._current_flash[key] = value
end
function session_get_flash_now(key)
return _G._current_flash and _G._current_flash[key]
end
-- ======================================================================
-- FLASH HELPER FUNCTIONS
-- ======================================================================
function flash_success(message)
session_flash("success", message)
end
function flash_error(message)
session_flash("error", message)
end
function flash_warning(message)
session_flash("warning", message)
end
function flash_info(message)
session_flash("info", message)
end
function flash_message(type, message)
session_flash(type, message)
end
-- Get flash messages by type
function get_flash_success()
return session_get_flash("success")
end
function get_flash_error()
return session_get_flash("error")
end
function get_flash_warning()
return session_get_flash("warning")
end
function get_flash_info()
return session_get_flash("info")
end
-- Check if flash messages exist
function has_flash_success()
return session_has_flash("success")
end
function has_flash_error()
return session_has_flash("error")
end
function has_flash_warning()
return session_has_flash("warning")
end
function has_flash_info()
return session_has_flash("info")
end
-- Convenience function for redirects with flash
function redirect_with_flash(url, type, message, status)
session_flash(type or "info", message)
http_redirect(url, status)
end
function redirect_with_success(url, message, status)
redirect_with_flash(url, "success", message, status)
end
function redirect_with_error(url, message, status)
redirect_with_flash(url, "error", message, status)
end

View File

@ -68,6 +68,7 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
"cookies": []any{}, "cookies": []any{},
"metadata": make(map[string]any), "metadata": make(map[string]any),
"session": make(map[string]any), "session": make(map[string]any),
"flash": make(map[string]any),
} }
// Call __execute(script_func, ctx, response) // Call __execute(script_func, ctx, response)
@ -170,6 +171,11 @@ func (s *Sandbox) buildResponse(luaResp map[string]any, body any) *Response {
maps.Copy(resp.SessionData, session) maps.Copy(resp.SessionData, session)
} }
// Extract flash data and add to metadata for processing by server
if flash, ok := luaResp["flash"].(map[string]any); ok && len(flash) > 0 {
resp.Metadata["flash"] = flash
}
return resp return resp
} }

View File

@ -13,6 +13,8 @@ import (
type Session struct { type Session struct {
ID string ID string
Data map[string]any Data map[string]any
FlashData map[string]any // Flash data for next request
OldFlash map[string]any // Flash data from previous request (to be cleared)
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
LastUsed time.Time LastUsed time.Time
@ -23,7 +25,11 @@ type Session struct {
var ( var (
sessionPool = sync.Pool{ sessionPool = sync.Pool{
New: func() any { New: func() any {
return &Session{Data: make(map[string]any, 8)} return &Session{
Data: make(map[string]any, 8),
FlashData: make(map[string]any, 4),
OldFlash: make(map[string]any, 4),
}
}, },
} }
bufPool = benc.NewBufPool(benc.WithBufferSize(4096)) bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
@ -35,7 +41,9 @@ func NewSession(id string, maxAge int) *Session {
now := time.Now() now := time.Now()
*s = Session{ *s = Session{
ID: id, ID: id,
Data: s.Data, // Reuse map Data: s.Data, // Reuse maps
FlashData: s.FlashData,
OldFlash: s.OldFlash,
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
LastUsed: now, LastUsed: now,
@ -49,6 +57,12 @@ func (s *Session) Release() {
for k := range s.Data { for k := range s.Data {
delete(s.Data, k) delete(s.Data, k)
} }
for k := range s.FlashData {
delete(s.FlashData, k)
}
for k := range s.OldFlash {
delete(s.OldFlash, k)
}
sessionPool.Put(s) sessionPool.Put(s)
} }
@ -70,12 +84,22 @@ func (s *Session) GetTable(key string) map[string]any {
return nil return nil
} }
// GetAll returns a deep copy of all session data // GetAll returns a deep copy of all session data including flash data
func (s *Session) GetAll() map[string]any { func (s *Session) GetAll() map[string]any {
copy := make(map[string]any, len(s.Data)) copy := make(map[string]any, len(s.Data)+len(s.FlashData)+len(s.OldFlash))
for k, v := range s.Data { for k, v := range s.Data {
copy[k] = deepCopy(v) copy[k] = deepCopy(v)
} }
// Include current flash data
for k, v := range s.FlashData {
copy[k] = deepCopy(v)
}
// Include old flash data (still available this request)
for k, v := range s.OldFlash {
if _, exists := copy[k]; !exists { // Don't override new flash
copy[k] = deepCopy(v)
}
}
return copy return copy
} }
@ -103,6 +127,77 @@ func (s *Session) SetTable(key string, table map[string]any) error {
return s.SetSafe(key, table) return s.SetSafe(key, table)
} }
// Flash stores a value that will be available for the next request only
func (s *Session) Flash(key string, value any) {
s.FlashData[key] = value
s.UpdatedAt = time.Now()
s.dirty = true
}
// FlashSafe stores a flash value with validation
func (s *Session) FlashSafe(key string, value any) error {
if err := validate(value); err != nil {
return fmt.Errorf("session.FlashSafe: %w", err)
}
s.Flash(key, value)
return nil
}
// GetFlash returns a flash value (from either current or old flash)
func (s *Session) GetFlash(key string) any {
// Check current flash first
if v, ok := s.FlashData[key]; ok {
return deepCopy(v)
}
// Check old flash
if v, ok := s.OldFlash[key]; ok {
return deepCopy(v)
}
return nil
}
// HasFlash checks if a flash key exists
func (s *Session) HasFlash(key string) bool {
_, inNew := s.FlashData[key]
_, inOld := s.OldFlash[key]
return inNew || inOld
}
// GetAllFlash returns all flash data (both current and old)
func (s *Session) GetAllFlash() map[string]any {
flash := make(map[string]any, len(s.FlashData)+len(s.OldFlash))
// Add old flash first
for k, v := range s.OldFlash {
flash[k] = deepCopy(v)
}
// Add current flash (overwrites old if same key)
for k, v := range s.FlashData {
flash[k] = deepCopy(v)
}
return flash
}
// AdvanceFlash moves current flash to old flash and clears old flash
// This should be called at the start of each request
func (s *Session) AdvanceFlash() {
// Clear old flash
for k := range s.OldFlash {
delete(s.OldFlash, k)
}
// Move current flash to old flash
if len(s.FlashData) > 0 {
for k, v := range s.FlashData {
s.OldFlash[k] = v
}
// Clear current flash
for k := range s.FlashData {
delete(s.FlashData, k)
}
s.dirty = true
}
}
// Delete removes a value from the session // Delete removes a value from the session
func (s *Session) Delete(key string) { func (s *Session) Delete(key string) {
delete(s.Data, key) delete(s.Data, key)
@ -117,6 +212,15 @@ func (s *Session) Clear() {
s.dirty = true s.dirty = true
} }
// ClearFlash removes all flash data
func (s *Session) ClearFlash() {
if len(s.FlashData) > 0 || len(s.OldFlash) > 0 {
s.FlashData = make(map[string]any, 4)
s.OldFlash = make(map[string]any, 4)
s.dirty = true
}
}
// IsExpired checks if the session has expired // IsExpired checks if the session has expired
func (s *Session) IsExpired() bool { func (s *Session) IsExpired() bool {
return time.Now().After(s.Expiry) return time.Now().After(s.Expiry)
@ -144,6 +248,8 @@ func (s *Session) ResetDirty() {
func (s *Session) SizePlain() int { func (s *Session) SizePlain() int {
return bstd.SizeString(s.ID) + return bstd.SizeString(s.ID) +
bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) + bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) +
bstd.SizeMap(s.FlashData, bstd.SizeString, sizeAny) +
bstd.SizeMap(s.OldFlash, bstd.SizeString, sizeAny) +
bstd.SizeInt64()*4 bstd.SizeInt64()*4
} }
@ -151,6 +257,8 @@ func (s *Session) SizePlain() int {
func (s *Session) MarshalPlain(n int, b []byte) int { func (s *Session) MarshalPlain(n int, b []byte) int {
n = bstd.MarshalString(n, b, s.ID) n = bstd.MarshalString(n, b, s.ID)
n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny) n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny)
n = bstd.MarshalMap(n, b, s.FlashData, bstd.MarshalString, marshalAny)
n = bstd.MarshalMap(n, b, s.OldFlash, bstd.MarshalString, marshalAny)
n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix())
n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix())
n = bstd.MarshalInt64(n, b, s.LastUsed.Unix()) n = bstd.MarshalInt64(n, b, s.LastUsed.Unix())
@ -170,6 +278,16 @@ func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) {
return n, err return n, err
} }
n, s.FlashData, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny)
if err != nil {
return n, err
}
n, s.OldFlash, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny)
if err != nil {
return n, err
}
var ts int64 var ts int64
for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} { for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} {
n, ts, err = bstd.UnmarshalInt64(n, b) n, ts, err = bstd.UnmarshalInt64(n, b)
@ -430,5 +548,5 @@ func deepEqual(a, b any) bool {
// IsEmpty returns true if the session has no data // IsEmpty returns true if the session has no data
func (s *Session) IsEmpty() bool { func (s *Session) IsEmpty() bool {
return len(s.Data) == 0 return len(s.Data) == 0 && len(s.FlashData) == 0 && len(s.OldFlash) == 0
} }