add flash session support, fix template rendering control flow
This commit is contained in:
parent
2c0067dfcf
commit
769a8dd2ce
@ -188,13 +188,19 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||
}()
|
||||
|
||||
session := s.sessionManager.GetSessionFromRequest(ctx)
|
||||
|
||||
// Advance flash data (move current flash to old, clear old)
|
||||
session.AdvanceFlash()
|
||||
|
||||
sessionMap["id"] = session.ID
|
||||
|
||||
// Only get session data if not empty
|
||||
// Get session data and flash data
|
||||
if !session.IsEmpty() {
|
||||
sessionMap["data"] = session.GetAll()
|
||||
sessionMap["data"] = session.GetAll() // This now includes flash data
|
||||
sessionMap["flash"] = session.GetAllFlash()
|
||||
} else {
|
||||
sessionMap["data"] = emptyMap
|
||||
sessionMap["flash"] = emptyMap
|
||||
}
|
||||
|
||||
// Set basic context
|
||||
@ -250,10 +256,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||
return
|
||||
}
|
||||
|
||||
// Handle session updates
|
||||
// Handle session updates including flash data
|
||||
if len(response.SessionData) > 0 {
|
||||
if _, clearAll := response.SessionData["__clear_all"]; clearAll {
|
||||
session.Clear()
|
||||
session.ClearFlash() // Also clear flash data
|
||||
delete(response.SessionData, "__clear_all")
|
||||
}
|
||||
|
||||
@ -266,6 +273,16 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||
}
|
||||
}
|
||||
|
||||
// Handle flash data from response
|
||||
if flashData, ok := response.Metadata["flash"].(map[string]any); ok {
|
||||
for k, v := range flashData {
|
||||
if err := session.FlashSafe(k, v); err != nil && s.debugMode {
|
||||
logger.Warnf("Error setting flash data %s: %v", k, err)
|
||||
}
|
||||
}
|
||||
delete(response.Metadata, "flash") // Remove from metadata after processing
|
||||
}
|
||||
|
||||
s.sessionManager.ApplySessionCookie(ctx, session)
|
||||
runner.ApplyResponse(response, ctx)
|
||||
runner.ReleaseResponse(response)
|
||||
|
@ -233,17 +233,23 @@ end
|
||||
|
||||
-- Template processing with code execution
|
||||
function render(template_str, env)
|
||||
local function get_line(s, ln)
|
||||
for line in s:gmatch("([^\n]*)\n?") do
|
||||
if ln == 1 then return line end
|
||||
ln = ln - 1
|
||||
end
|
||||
end
|
||||
|
||||
local function pos_to_line(s, pos)
|
||||
local line = 1
|
||||
for _ in s:sub(1, pos):gmatch("\n") do line = line + 1 end
|
||||
return line
|
||||
local function is_control_structure(code)
|
||||
-- Check if code is a control structure that doesn't produce output
|
||||
local trimmed = code:match("^%s*(.-)%s*$")
|
||||
return trimmed == "else" or
|
||||
trimmed == "end" or
|
||||
trimmed:match("^if%s") or
|
||||
trimmed:match("^elseif%s") or
|
||||
trimmed:match("^for%s") or
|
||||
trimmed:match("^while%s") or
|
||||
trimmed:match("^repeat%s*$") or
|
||||
trimmed:match("^until%s") or
|
||||
trimmed:match("^do%s*$") or
|
||||
trimmed:match("^local%s") or
|
||||
trimmed:match("^function%s") or
|
||||
trimmed:match(".*=%s*function%s*%(") or
|
||||
trimmed:match(".*then%s*$") or
|
||||
trimmed:match(".*do%s*$")
|
||||
end
|
||||
|
||||
local pos, chunks = 1, {}
|
||||
@ -273,11 +279,9 @@ function render(template_str, env)
|
||||
end
|
||||
|
||||
local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$")
|
||||
local is_control = is_control_structure(code)
|
||||
|
||||
-- Check if it's a simple variable name for escaped output
|
||||
local is_simple_var = tag_type == "=" and code:match("^[%w_]+$")
|
||||
|
||||
table.insert(chunks, {tag_type, code, pos, is_simple_var})
|
||||
table.insert(chunks, {tag_type, code, pos, is_control})
|
||||
pos = close_stop + 1
|
||||
end
|
||||
|
||||
@ -288,24 +292,43 @@ function render(template_str, env)
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "_b[_b_i] = " .. string.format("%q", chunk) .. "\n")
|
||||
else
|
||||
t = chunk[1]
|
||||
if t == "=" then
|
||||
if chunk[4] then -- is_simple_var
|
||||
local tag_type, code, pos, is_control = chunk[1], chunk[2], chunk[3], chunk[4]
|
||||
|
||||
if is_control then
|
||||
-- Control structure - just insert as raw Lua code
|
||||
table.insert(buffer, "--[[" .. pos .. "]] " .. code .. "\n")
|
||||
elseif tag_type == "=" then
|
||||
-- Simple variable check
|
||||
if code:match("^[%w_]+$") then
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _escape(_tostring(" .. chunk[2] .. "))\n")
|
||||
table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n")
|
||||
else
|
||||
table.insert(buffer, "--[[" .. chunk[3] .. "]] " .. chunk[2] .. "\n")
|
||||
-- Expression output with escaping
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n")
|
||||
end
|
||||
elseif t == "-" then
|
||||
elseif tag_type == "-" then
|
||||
-- Unescaped output
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _tostring(" .. chunk[2] .. ")\n")
|
||||
table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _tostring(" .. code .. ")\n")
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(buffer, "return _b")
|
||||
|
||||
local fn, err = loadstring(table.concat(buffer))
|
||||
if not fn then error(err) end
|
||||
local generated_code = table.concat(buffer)
|
||||
|
||||
-- DEBUG: Uncomment to see generated code
|
||||
-- print("Generated Lua code:")
|
||||
-- print(generated_code)
|
||||
-- print("---")
|
||||
|
||||
local fn, err = loadstring(generated_code)
|
||||
if not fn then
|
||||
print("Generated code that failed to compile:")
|
||||
print(generated_code)
|
||||
error(err)
|
||||
end
|
||||
|
||||
env = env or {}
|
||||
local runtime_env = setmetatable({}, {__index = function(_, k) return env[k] or _G[k] end})
|
||||
@ -428,3 +451,142 @@ function password_verify(plain_password, hash_string)
|
||||
|
||||
return __password_verify(plain_password, hash_string)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- SESSION FLASH FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function session_flash(key, value)
|
||||
__response.flash = __response.flash or {}
|
||||
__response.flash[key] = value
|
||||
end
|
||||
|
||||
function session_get_flash(key)
|
||||
-- Check current flash data first
|
||||
if __response.flash and __response.flash[key] ~= nil then
|
||||
return __response.flash[key]
|
||||
end
|
||||
|
||||
-- Check session flash data
|
||||
if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then
|
||||
return __ctx.session.flash[key]
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
function session_has_flash(key)
|
||||
-- Check current flash
|
||||
if __response.flash and __response.flash[key] ~= nil then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check session flash
|
||||
if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
function session_get_all_flash()
|
||||
local flash = {}
|
||||
|
||||
-- Add session flash data first
|
||||
if __ctx.session and __ctx.session.flash then
|
||||
for k, v in pairs(__ctx.session.flash) do
|
||||
flash[k] = v
|
||||
end
|
||||
end
|
||||
|
||||
-- Add current response flash (overwrites session flash if same key)
|
||||
if __response.flash then
|
||||
for k, v in pairs(__response.flash) do
|
||||
flash[k] = v
|
||||
end
|
||||
end
|
||||
|
||||
return flash
|
||||
end
|
||||
|
||||
function session_flash_now(key, value)
|
||||
-- Flash for current request only (not persisted)
|
||||
_G._current_flash = _G._current_flash or {}
|
||||
_G._current_flash[key] = value
|
||||
end
|
||||
|
||||
function session_get_flash_now(key)
|
||||
return _G._current_flash and _G._current_flash[key]
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- FLASH HELPER FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function flash_success(message)
|
||||
session_flash("success", message)
|
||||
end
|
||||
|
||||
function flash_error(message)
|
||||
session_flash("error", message)
|
||||
end
|
||||
|
||||
function flash_warning(message)
|
||||
session_flash("warning", message)
|
||||
end
|
||||
|
||||
function flash_info(message)
|
||||
session_flash("info", message)
|
||||
end
|
||||
|
||||
function flash_message(type, message)
|
||||
session_flash(type, message)
|
||||
end
|
||||
|
||||
-- Get flash messages by type
|
||||
function get_flash_success()
|
||||
return session_get_flash("success")
|
||||
end
|
||||
|
||||
function get_flash_error()
|
||||
return session_get_flash("error")
|
||||
end
|
||||
|
||||
function get_flash_warning()
|
||||
return session_get_flash("warning")
|
||||
end
|
||||
|
||||
function get_flash_info()
|
||||
return session_get_flash("info")
|
||||
end
|
||||
|
||||
-- Check if flash messages exist
|
||||
function has_flash_success()
|
||||
return session_has_flash("success")
|
||||
end
|
||||
|
||||
function has_flash_error()
|
||||
return session_has_flash("error")
|
||||
end
|
||||
|
||||
function has_flash_warning()
|
||||
return session_has_flash("warning")
|
||||
end
|
||||
|
||||
function has_flash_info()
|
||||
return session_has_flash("info")
|
||||
end
|
||||
|
||||
-- Convenience function for redirects with flash
|
||||
function redirect_with_flash(url, type, message, status)
|
||||
session_flash(type or "info", message)
|
||||
http_redirect(url, status)
|
||||
end
|
||||
|
||||
function redirect_with_success(url, message, status)
|
||||
redirect_with_flash(url, "success", message, status)
|
||||
end
|
||||
|
||||
function redirect_with_error(url, message, status)
|
||||
redirect_with_flash(url, "error", message, status)
|
||||
end
|
||||
|
@ -68,6 +68,7 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
|
||||
"cookies": []any{},
|
||||
"metadata": make(map[string]any),
|
||||
"session": make(map[string]any),
|
||||
"flash": make(map[string]any),
|
||||
}
|
||||
|
||||
// Call __execute(script_func, ctx, response)
|
||||
@ -170,6 +171,11 @@ func (s *Sandbox) buildResponse(luaResp map[string]any, body any) *Response {
|
||||
maps.Copy(resp.SessionData, session)
|
||||
}
|
||||
|
||||
// Extract flash data and add to metadata for processing by server
|
||||
if flash, ok := luaResp["flash"].(map[string]any); ok && len(flash) > 0 {
|
||||
resp.Metadata["flash"] = flash
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,8 @@ import (
|
||||
type Session struct {
|
||||
ID string
|
||||
Data map[string]any
|
||||
FlashData map[string]any // Flash data for next request
|
||||
OldFlash map[string]any // Flash data from previous request (to be cleared)
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
LastUsed time.Time
|
||||
@ -23,7 +25,11 @@ type Session struct {
|
||||
var (
|
||||
sessionPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Session{Data: make(map[string]any, 8)}
|
||||
return &Session{
|
||||
Data: make(map[string]any, 8),
|
||||
FlashData: make(map[string]any, 4),
|
||||
OldFlash: make(map[string]any, 4),
|
||||
}
|
||||
},
|
||||
}
|
||||
bufPool = benc.NewBufPool(benc.WithBufferSize(4096))
|
||||
@ -35,7 +41,9 @@ func NewSession(id string, maxAge int) *Session {
|
||||
now := time.Now()
|
||||
*s = Session{
|
||||
ID: id,
|
||||
Data: s.Data, // Reuse map
|
||||
Data: s.Data, // Reuse maps
|
||||
FlashData: s.FlashData,
|
||||
OldFlash: s.OldFlash,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
LastUsed: now,
|
||||
@ -49,6 +57,12 @@ func (s *Session) Release() {
|
||||
for k := range s.Data {
|
||||
delete(s.Data, k)
|
||||
}
|
||||
for k := range s.FlashData {
|
||||
delete(s.FlashData, k)
|
||||
}
|
||||
for k := range s.OldFlash {
|
||||
delete(s.OldFlash, k)
|
||||
}
|
||||
sessionPool.Put(s)
|
||||
}
|
||||
|
||||
@ -70,12 +84,22 @@ func (s *Session) GetTable(key string) map[string]any {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAll returns a deep copy of all session data
|
||||
// GetAll returns a deep copy of all session data including flash data
|
||||
func (s *Session) GetAll() map[string]any {
|
||||
copy := make(map[string]any, len(s.Data))
|
||||
copy := make(map[string]any, len(s.Data)+len(s.FlashData)+len(s.OldFlash))
|
||||
for k, v := range s.Data {
|
||||
copy[k] = deepCopy(v)
|
||||
}
|
||||
// Include current flash data
|
||||
for k, v := range s.FlashData {
|
||||
copy[k] = deepCopy(v)
|
||||
}
|
||||
// Include old flash data (still available this request)
|
||||
for k, v := range s.OldFlash {
|
||||
if _, exists := copy[k]; !exists { // Don't override new flash
|
||||
copy[k] = deepCopy(v)
|
||||
}
|
||||
}
|
||||
return copy
|
||||
}
|
||||
|
||||
@ -103,6 +127,77 @@ func (s *Session) SetTable(key string, table map[string]any) error {
|
||||
return s.SetSafe(key, table)
|
||||
}
|
||||
|
||||
// Flash stores a value that will be available for the next request only
|
||||
func (s *Session) Flash(key string, value any) {
|
||||
s.FlashData[key] = value
|
||||
s.UpdatedAt = time.Now()
|
||||
s.dirty = true
|
||||
}
|
||||
|
||||
// FlashSafe stores a flash value with validation
|
||||
func (s *Session) FlashSafe(key string, value any) error {
|
||||
if err := validate(value); err != nil {
|
||||
return fmt.Errorf("session.FlashSafe: %w", err)
|
||||
}
|
||||
s.Flash(key, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFlash returns a flash value (from either current or old flash)
|
||||
func (s *Session) GetFlash(key string) any {
|
||||
// Check current flash first
|
||||
if v, ok := s.FlashData[key]; ok {
|
||||
return deepCopy(v)
|
||||
}
|
||||
// Check old flash
|
||||
if v, ok := s.OldFlash[key]; ok {
|
||||
return deepCopy(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasFlash checks if a flash key exists
|
||||
func (s *Session) HasFlash(key string) bool {
|
||||
_, inNew := s.FlashData[key]
|
||||
_, inOld := s.OldFlash[key]
|
||||
return inNew || inOld
|
||||
}
|
||||
|
||||
// GetAllFlash returns all flash data (both current and old)
|
||||
func (s *Session) GetAllFlash() map[string]any {
|
||||
flash := make(map[string]any, len(s.FlashData)+len(s.OldFlash))
|
||||
// Add old flash first
|
||||
for k, v := range s.OldFlash {
|
||||
flash[k] = deepCopy(v)
|
||||
}
|
||||
// Add current flash (overwrites old if same key)
|
||||
for k, v := range s.FlashData {
|
||||
flash[k] = deepCopy(v)
|
||||
}
|
||||
return flash
|
||||
}
|
||||
|
||||
// AdvanceFlash moves current flash to old flash and clears old flash
|
||||
// This should be called at the start of each request
|
||||
func (s *Session) AdvanceFlash() {
|
||||
// Clear old flash
|
||||
for k := range s.OldFlash {
|
||||
delete(s.OldFlash, k)
|
||||
}
|
||||
|
||||
// Move current flash to old flash
|
||||
if len(s.FlashData) > 0 {
|
||||
for k, v := range s.FlashData {
|
||||
s.OldFlash[k] = v
|
||||
}
|
||||
// Clear current flash
|
||||
for k := range s.FlashData {
|
||||
delete(s.FlashData, k)
|
||||
}
|
||||
s.dirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// Delete removes a value from the session
|
||||
func (s *Session) Delete(key string) {
|
||||
delete(s.Data, key)
|
||||
@ -117,6 +212,15 @@ func (s *Session) Clear() {
|
||||
s.dirty = true
|
||||
}
|
||||
|
||||
// ClearFlash removes all flash data
|
||||
func (s *Session) ClearFlash() {
|
||||
if len(s.FlashData) > 0 || len(s.OldFlash) > 0 {
|
||||
s.FlashData = make(map[string]any, 4)
|
||||
s.OldFlash = make(map[string]any, 4)
|
||||
s.dirty = true
|
||||
}
|
||||
}
|
||||
|
||||
// IsExpired checks if the session has expired
|
||||
func (s *Session) IsExpired() bool {
|
||||
return time.Now().After(s.Expiry)
|
||||
@ -144,6 +248,8 @@ func (s *Session) ResetDirty() {
|
||||
func (s *Session) SizePlain() int {
|
||||
return bstd.SizeString(s.ID) +
|
||||
bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) +
|
||||
bstd.SizeMap(s.FlashData, bstd.SizeString, sizeAny) +
|
||||
bstd.SizeMap(s.OldFlash, bstd.SizeString, sizeAny) +
|
||||
bstd.SizeInt64()*4
|
||||
}
|
||||
|
||||
@ -151,6 +257,8 @@ func (s *Session) SizePlain() int {
|
||||
func (s *Session) MarshalPlain(n int, b []byte) int {
|
||||
n = bstd.MarshalString(n, b, s.ID)
|
||||
n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny)
|
||||
n = bstd.MarshalMap(n, b, s.FlashData, bstd.MarshalString, marshalAny)
|
||||
n = bstd.MarshalMap(n, b, s.OldFlash, bstd.MarshalString, marshalAny)
|
||||
n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix())
|
||||
n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix())
|
||||
n = bstd.MarshalInt64(n, b, s.LastUsed.Unix())
|
||||
@ -170,6 +278,16 @@ func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
n, s.FlashData, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
n, s.OldFlash, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
var ts int64
|
||||
for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} {
|
||||
n, ts, err = bstd.UnmarshalInt64(n, b)
|
||||
@ -430,5 +548,5 @@ func deepEqual(a, b any) bool {
|
||||
|
||||
// IsEmpty returns true if the session has no data
|
||||
func (s *Session) IsEmpty() bool {
|
||||
return len(s.Data) == 0
|
||||
return len(s.Data) == 0 && len(s.FlashData) == 0 && len(s.OldFlash) == 0
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user