diff --git a/http/server.go b/http/server.go index 2af501a..a28d662 100644 --- a/http/server.go +++ b/http/server.go @@ -188,13 +188,19 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip }() session := s.sessionManager.GetSessionFromRequest(ctx) + + // Advance flash data (move current flash to old, clear old) + session.AdvanceFlash() + sessionMap["id"] = session.ID - // Only get session data if not empty + // Get session data and flash data if !session.IsEmpty() { - sessionMap["data"] = session.GetAll() + sessionMap["data"] = session.GetAll() // This now includes flash data + sessionMap["flash"] = session.GetAllFlash() } else { sessionMap["data"] = emptyMap + sessionMap["flash"] = emptyMap } // Set basic context @@ -250,10 +256,11 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip return } - // Handle session updates + // Handle session updates including flash data if len(response.SessionData) > 0 { if _, clearAll := response.SessionData["__clear_all"]; clearAll { session.Clear() + session.ClearFlash() // Also clear flash data delete(response.SessionData, "__clear_all") } @@ -266,6 +273,16 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip } } + // Handle flash data from response + if flashData, ok := response.Metadata["flash"].(map[string]any); ok { + for k, v := range flashData { + if err := session.FlashSafe(k, v); err != nil && s.debugMode { + logger.Warnf("Error setting flash data %s: %v", k, err) + } + } + delete(response.Metadata, "flash") // Remove from metadata after processing + } + s.sessionManager.ApplySessionCookie(ctx, session) runner.ApplyResponse(response, ctx) runner.ReleaseResponse(response) diff --git a/runner/lua/sandbox.lua b/runner/lua/sandbox.lua index fb65982..dfd61d6 100644 --- a/runner/lua/sandbox.lua +++ b/runner/lua/sandbox.lua @@ -233,17 +233,23 @@ end -- Template processing with code execution function render(template_str, env) - local function get_line(s, ln) - for line in s:gmatch("([^\n]*)\n?") do - if ln == 1 then return line end - ln = ln - 1 - end - end - - local function pos_to_line(s, pos) - local line = 1 - for _ in s:sub(1, pos):gmatch("\n") do line = line + 1 end - return line + local function is_control_structure(code) + -- Check if code is a control structure that doesn't produce output + local trimmed = code:match("^%s*(.-)%s*$") + return trimmed == "else" or + trimmed == "end" or + trimmed:match("^if%s") or + trimmed:match("^elseif%s") or + trimmed:match("^for%s") or + trimmed:match("^while%s") or + trimmed:match("^repeat%s*$") or + trimmed:match("^until%s") or + trimmed:match("^do%s*$") or + trimmed:match("^local%s") or + trimmed:match("^function%s") or + trimmed:match(".*=%s*function%s*%(") or + trimmed:match(".*then%s*$") or + trimmed:match(".*do%s*$") end local pos, chunks = 1, {} @@ -273,11 +279,9 @@ function render(template_str, env) end local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$") + local is_control = is_control_structure(code) - -- Check if it's a simple variable name for escaped output - local is_simple_var = tag_type == "=" and code:match("^[%w_]+$") - - table.insert(chunks, {tag_type, code, pos, is_simple_var}) + table.insert(chunks, {tag_type, code, pos, is_control}) pos = close_stop + 1 end @@ -288,24 +292,43 @@ function render(template_str, env) table.insert(buffer, "_b_i = _b_i + 1\n") table.insert(buffer, "_b[_b_i] = " .. string.format("%q", chunk) .. "\n") else - t = chunk[1] - if t == "=" then - if chunk[4] then -- is_simple_var + local tag_type, code, pos, is_control = chunk[1], chunk[2], chunk[3], chunk[4] + + if is_control then + -- Control structure - just insert as raw Lua code + table.insert(buffer, "--[[" .. pos .. "]] " .. code .. "\n") + elseif tag_type == "=" then + -- Simple variable check + if code:match("^[%w_]+$") then table.insert(buffer, "_b_i = _b_i + 1\n") - table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _escape(_tostring(" .. chunk[2] .. "))\n") + table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n") else - table.insert(buffer, "--[[" .. chunk[3] .. "]] " .. chunk[2] .. "\n") + -- Expression output with escaping + table.insert(buffer, "_b_i = _b_i + 1\n") + table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n") end - elseif t == "-" then + elseif tag_type == "-" then + -- Unescaped output table.insert(buffer, "_b_i = _b_i + 1\n") - table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _tostring(" .. chunk[2] .. ")\n") + table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _tostring(" .. code .. ")\n") end end end table.insert(buffer, "return _b") - local fn, err = loadstring(table.concat(buffer)) - if not fn then error(err) end + local generated_code = table.concat(buffer) + + -- DEBUG: Uncomment to see generated code + -- print("Generated Lua code:") + -- print(generated_code) + -- print("---") + + local fn, err = loadstring(generated_code) + if not fn then + print("Generated code that failed to compile:") + print(generated_code) + error(err) + end env = env or {} local runtime_env = setmetatable({}, {__index = function(_, k) return env[k] or _G[k] end}) @@ -428,3 +451,142 @@ function password_verify(plain_password, hash_string) return __password_verify(plain_password, hash_string) end + +-- ====================================================================== +-- SESSION FLASH FUNCTIONS +-- ====================================================================== + +function session_flash(key, value) + __response.flash = __response.flash or {} + __response.flash[key] = value +end + +function session_get_flash(key) + -- Check current flash data first + if __response.flash and __response.flash[key] ~= nil then + return __response.flash[key] + end + + -- Check session flash data + if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then + return __ctx.session.flash[key] + end + + return nil +end + +function session_has_flash(key) + -- Check current flash + if __response.flash and __response.flash[key] ~= nil then + return true + end + + -- Check session flash + if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then + return true + end + + return false +end + +function session_get_all_flash() + local flash = {} + + -- Add session flash data first + if __ctx.session and __ctx.session.flash then + for k, v in pairs(__ctx.session.flash) do + flash[k] = v + end + end + + -- Add current response flash (overwrites session flash if same key) + if __response.flash then + for k, v in pairs(__response.flash) do + flash[k] = v + end + end + + return flash +end + +function session_flash_now(key, value) + -- Flash for current request only (not persisted) + _G._current_flash = _G._current_flash or {} + _G._current_flash[key] = value +end + +function session_get_flash_now(key) + return _G._current_flash and _G._current_flash[key] +end + +-- ====================================================================== +-- FLASH HELPER FUNCTIONS +-- ====================================================================== + +function flash_success(message) + session_flash("success", message) +end + +function flash_error(message) + session_flash("error", message) +end + +function flash_warning(message) + session_flash("warning", message) +end + +function flash_info(message) + session_flash("info", message) +end + +function flash_message(type, message) + session_flash(type, message) +end + +-- Get flash messages by type +function get_flash_success() + return session_get_flash("success") +end + +function get_flash_error() + return session_get_flash("error") +end + +function get_flash_warning() + return session_get_flash("warning") +end + +function get_flash_info() + return session_get_flash("info") +end + +-- Check if flash messages exist +function has_flash_success() + return session_has_flash("success") +end + +function has_flash_error() + return session_has_flash("error") +end + +function has_flash_warning() + return session_has_flash("warning") +end + +function has_flash_info() + return session_has_flash("info") +end + +-- Convenience function for redirects with flash +function redirect_with_flash(url, type, message, status) + session_flash(type or "info", message) + http_redirect(url, status) +end + +function redirect_with_success(url, message, status) + redirect_with_flash(url, "success", message, status) +end + +function redirect_with_error(url, message, status) + redirect_with_flash(url, "error", message, status) +end diff --git a/runner/sandbox.go b/runner/sandbox.go index e48762c..ff33d8c 100644 --- a/runner/sandbox.go +++ b/runner/sandbox.go @@ -68,6 +68,7 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (* "cookies": []any{}, "metadata": make(map[string]any), "session": make(map[string]any), + "flash": make(map[string]any), } // Call __execute(script_func, ctx, response) @@ -170,6 +171,11 @@ func (s *Sandbox) buildResponse(luaResp map[string]any, body any) *Response { maps.Copy(resp.SessionData, session) } + // Extract flash data and add to metadata for processing by server + if flash, ok := luaResp["flash"].(map[string]any); ok && len(flash) > 0 { + resp.Metadata["flash"] = flash + } + return resp } diff --git a/sessions/session.go b/sessions/session.go index 28c8a47..daabb88 100644 --- a/sessions/session.go +++ b/sessions/session.go @@ -13,6 +13,8 @@ import ( type Session struct { ID string Data map[string]any + FlashData map[string]any // Flash data for next request + OldFlash map[string]any // Flash data from previous request (to be cleared) CreatedAt time.Time UpdatedAt time.Time LastUsed time.Time @@ -23,7 +25,11 @@ type Session struct { var ( sessionPool = sync.Pool{ New: func() any { - return &Session{Data: make(map[string]any, 8)} + return &Session{ + Data: make(map[string]any, 8), + FlashData: make(map[string]any, 4), + OldFlash: make(map[string]any, 4), + } }, } bufPool = benc.NewBufPool(benc.WithBufferSize(4096)) @@ -35,7 +41,9 @@ func NewSession(id string, maxAge int) *Session { now := time.Now() *s = Session{ ID: id, - Data: s.Data, // Reuse map + Data: s.Data, // Reuse maps + FlashData: s.FlashData, + OldFlash: s.OldFlash, CreatedAt: now, UpdatedAt: now, LastUsed: now, @@ -49,6 +57,12 @@ func (s *Session) Release() { for k := range s.Data { delete(s.Data, k) } + for k := range s.FlashData { + delete(s.FlashData, k) + } + for k := range s.OldFlash { + delete(s.OldFlash, k) + } sessionPool.Put(s) } @@ -70,12 +84,22 @@ func (s *Session) GetTable(key string) map[string]any { return nil } -// GetAll returns a deep copy of all session data +// GetAll returns a deep copy of all session data including flash data func (s *Session) GetAll() map[string]any { - copy := make(map[string]any, len(s.Data)) + copy := make(map[string]any, len(s.Data)+len(s.FlashData)+len(s.OldFlash)) for k, v := range s.Data { copy[k] = deepCopy(v) } + // Include current flash data + for k, v := range s.FlashData { + copy[k] = deepCopy(v) + } + // Include old flash data (still available this request) + for k, v := range s.OldFlash { + if _, exists := copy[k]; !exists { // Don't override new flash + copy[k] = deepCopy(v) + } + } return copy } @@ -103,6 +127,77 @@ func (s *Session) SetTable(key string, table map[string]any) error { return s.SetSafe(key, table) } +// Flash stores a value that will be available for the next request only +func (s *Session) Flash(key string, value any) { + s.FlashData[key] = value + s.UpdatedAt = time.Now() + s.dirty = true +} + +// FlashSafe stores a flash value with validation +func (s *Session) FlashSafe(key string, value any) error { + if err := validate(value); err != nil { + return fmt.Errorf("session.FlashSafe: %w", err) + } + s.Flash(key, value) + return nil +} + +// GetFlash returns a flash value (from either current or old flash) +func (s *Session) GetFlash(key string) any { + // Check current flash first + if v, ok := s.FlashData[key]; ok { + return deepCopy(v) + } + // Check old flash + if v, ok := s.OldFlash[key]; ok { + return deepCopy(v) + } + return nil +} + +// HasFlash checks if a flash key exists +func (s *Session) HasFlash(key string) bool { + _, inNew := s.FlashData[key] + _, inOld := s.OldFlash[key] + return inNew || inOld +} + +// GetAllFlash returns all flash data (both current and old) +func (s *Session) GetAllFlash() map[string]any { + flash := make(map[string]any, len(s.FlashData)+len(s.OldFlash)) + // Add old flash first + for k, v := range s.OldFlash { + flash[k] = deepCopy(v) + } + // Add current flash (overwrites old if same key) + for k, v := range s.FlashData { + flash[k] = deepCopy(v) + } + return flash +} + +// AdvanceFlash moves current flash to old flash and clears old flash +// This should be called at the start of each request +func (s *Session) AdvanceFlash() { + // Clear old flash + for k := range s.OldFlash { + delete(s.OldFlash, k) + } + + // Move current flash to old flash + if len(s.FlashData) > 0 { + for k, v := range s.FlashData { + s.OldFlash[k] = v + } + // Clear current flash + for k := range s.FlashData { + delete(s.FlashData, k) + } + s.dirty = true + } +} + // Delete removes a value from the session func (s *Session) Delete(key string) { delete(s.Data, key) @@ -117,6 +212,15 @@ func (s *Session) Clear() { s.dirty = true } +// ClearFlash removes all flash data +func (s *Session) ClearFlash() { + if len(s.FlashData) > 0 || len(s.OldFlash) > 0 { + s.FlashData = make(map[string]any, 4) + s.OldFlash = make(map[string]any, 4) + s.dirty = true + } +} + // IsExpired checks if the session has expired func (s *Session) IsExpired() bool { return time.Now().After(s.Expiry) @@ -144,6 +248,8 @@ func (s *Session) ResetDirty() { func (s *Session) SizePlain() int { return bstd.SizeString(s.ID) + bstd.SizeMap(s.Data, bstd.SizeString, sizeAny) + + bstd.SizeMap(s.FlashData, bstd.SizeString, sizeAny) + + bstd.SizeMap(s.OldFlash, bstd.SizeString, sizeAny) + bstd.SizeInt64()*4 } @@ -151,6 +257,8 @@ func (s *Session) SizePlain() int { func (s *Session) MarshalPlain(n int, b []byte) int { n = bstd.MarshalString(n, b, s.ID) n = bstd.MarshalMap(n, b, s.Data, bstd.MarshalString, marshalAny) + n = bstd.MarshalMap(n, b, s.FlashData, bstd.MarshalString, marshalAny) + n = bstd.MarshalMap(n, b, s.OldFlash, bstd.MarshalString, marshalAny) n = bstd.MarshalInt64(n, b, s.CreatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.UpdatedAt.Unix()) n = bstd.MarshalInt64(n, b, s.LastUsed.Unix()) @@ -170,6 +278,16 @@ func (s *Session) UnmarshalPlain(n int, b []byte) (int, error) { return n, err } + n, s.FlashData, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny) + if err != nil { + return n, err + } + + n, s.OldFlash, err = bstd.UnmarshalMap[string, any](n, b, bstd.UnmarshalString, unmarshalAny) + if err != nil { + return n, err + } + var ts int64 for _, t := range []*time.Time{&s.CreatedAt, &s.UpdatedAt, &s.LastUsed, &s.Expiry} { n, ts, err = bstd.UnmarshalInt64(n, b) @@ -430,5 +548,5 @@ func deepEqual(a, b any) bool { // IsEmpty returns true if the session has no data func (s *Session) IsEmpty() bool { - return len(s.Data) == 0 + return len(s.Data) == 0 && len(s.FlashData) == 0 && len(s.OldFlash) == 0 }