From 843e318e011a3a01676410d6217f96708b5673c7 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Mon, 14 Jul 2025 17:36:59 -0500 Subject: [PATCH] next pass --- http/http.go | 843 +++++++++++++++++------------------------- http/http.lua | 144 +++++--- http/router/router.go | 287 +++++++------- 3 files changed, 565 insertions(+), 709 deletions(-) diff --git a/http/http.go b/http/http.go index b05b88c..060fc5b 100644 --- a/http/http.go +++ b/http/http.go @@ -21,25 +21,33 @@ import ( //go:embed http.lua var httpLuaCode string +// HandlerFunc represents a Lua handler +type HandlerFunc struct { + bytecode []byte + funcRef int + name string + isFunction bool +} + +// Server with single state for function handler compatibility type Server struct { server *fasthttp.Server router *router.Router sessions *sessions.SessionManager state *luajit.State stateMu sync.Mutex + handlers map[int]*HandlerFunc + handlersMu sync.RWMutex funcCounter int } +// RequestContext with lazy parsing type RequestContext struct { - Method string - Path string - Headers map[string]string - Query map[string]string - Form map[string]any - Cookies map[string]string - Session *sessions.Session - Body string - Params map[string]string + ctx *fasthttp.RequestCtx + params *router.Params + session *sessions.Session + parsedForm map[string]any + formOnce sync.Once } var globalServer *Server @@ -49,6 +57,7 @@ func NewServer(state *luajit.State) *Server { router: router.New(), sessions: sessions.NewSessionManager(10000), state: state, + handlers: make(map[int]*HandlerFunc), } } @@ -56,19 +65,8 @@ func RegisterHTTPFunctions(L *luajit.State) error { globalServer = NewServer(L) functions := map[string]luajit.GoFunction{ - "__http_listen": globalServer.httpListen, - "__http_route": globalServer.httpRoute, - "__http_set_status": httpSetStatus, - "__http_set_header": httpSetHeader, - "__http_redirect": httpRedirect, - "__session_get": globalServer.sessionGet, - "__session_set": globalServer.sessionSet, - "__session_flash": globalServer.sessionFlash, - "__session_get_flash": globalServer.sessionGetFlash, - "__cookie_set": cookieSet, - "__cookie_get": cookieGet, - "__csrf_generate": globalServer.csrfGenerate, - "__csrf_validate": globalServer.csrfValidate, + "__http_listen": globalServer.httpListen, + "__http_route": globalServer.httpRoute, } for name, fn := range functions { @@ -87,8 +85,13 @@ func (s *Server) httpListen(state *luajit.State) int { } s.server = &fasthttp.Server{ - Handler: s.requestHandler, - Name: "Moonshark/1.0", + Handler: s.fastRequestHandler, + Name: "Moonshark/2.0", + Concurrency: 256 * 1024, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + ReduceMemoryUsage: true, + NoDefaultServerHeader: true, } addr := fmt.Sprintf(":%d", int(port)) @@ -114,16 +117,44 @@ func (s *Server) httpRoute(state *luajit.State) int { return state.PushError("route: path must be string") } - if !state.IsFunction(3) { - return state.PushError("route: handler must be function") + s.funcCounter++ + handlerID := s.funcCounter + + if state.IsFunction(3) { + // Function handler - store reference + state.PushCopy(3) + funcRef := s.storeFunction(state) + + s.handlersMu.Lock() + s.handlers[handlerID] = &HandlerFunc{ + funcRef: funcRef, + name: fmt.Sprintf("%s %s", method, path), + isFunction: true, + } + s.handlersMu.Unlock() + } else { + // String handler - compile to bytecode + handlerCode, err := state.SafeToString(3) + if err != nil { + return state.PushError("route: handler must be function or string") + } + + bytecode, err := state.CompileBytecode(handlerCode, fmt.Sprintf("handler_%s_%s", method, path)) + if err != nil { + return state.PushError("route: failed to compile handler: %s", err.Error()) + } + + s.handlersMu.Lock() + s.handlers[handlerID] = &HandlerFunc{ + bytecode: bytecode, + name: fmt.Sprintf("%s %s", method, path), + isFunction: false, + } + s.handlersMu.Unlock() } - // Store function and get reference - state.PushCopy(3) - funcRef := s.storeFunction() - // Add route to router - if err := s.router.AddRoute(strings.ToUpper(method), path, funcRef); err != nil { + if err := s.router.AddRoute(strings.ToUpper(method), path, handlerID); err != nil { return state.PushError("route: failed to add route: %s", err.Error()) } @@ -131,133 +162,282 @@ func (s *Server) httpRoute(state *luajit.State) int { return 1 } -func (s *Server) storeFunction() int { - s.state.GetGlobal("__moonshark_functions") - if s.state.IsNil(-1) { - s.state.Pop(1) - s.state.NewTable() - s.state.PushCopy(-1) - s.state.SetGlobal("__moonshark_functions") +func (s *Server) storeFunction(state *luajit.State) int { + state.GetGlobal("__moonshark_functions") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.PushCopy(-1) + state.SetGlobal("__moonshark_functions") } s.funcCounter++ - s.state.PushNumber(float64(s.funcCounter)) - s.state.PushCopy(-3) - s.state.SetTable(-3) - s.state.Pop(2) + state.PushNumber(float64(s.funcCounter)) + state.PushCopy(-3) + state.SetTable(-3) + state.Pop(2) return s.funcCounter } -func (s *Server) getFunction(ref int) bool { - s.state.GetGlobal("__moonshark_functions") - if s.state.IsNil(-1) { - s.state.Pop(1) +func (s *Server) getFunction(state *luajit.State, ref int) bool { + state.GetGlobal("__moonshark_functions") + if state.IsNil(-1) { + state.Pop(1) return false } - s.state.PushNumber(float64(ref)) - s.state.GetTable(-2) - isFunc := s.state.IsFunction(-1) + state.PushNumber(float64(ref)) + state.GetTable(-2) + isFunc := state.IsFunction(-1) if !isFunc { - s.state.Pop(2) + state.Pop(2) return false } - s.state.Remove(-2) + state.Remove(-2) return true } -func (s *Server) requestHandler(ctx *fasthttp.RequestCtx) { +func (s *Server) fastRequestHandler(ctx *fasthttp.RequestCtx) { method := string(ctx.Method()) path := string(ctx.Path()) - // Look up route in router - handlerRef, params, found := s.router.Lookup(method, path) + // Fast route lookup + handlerID, params, found := s.router.Lookup(method, path) if !found { ctx.SetStatusCode(404) ctx.SetBodyString("Not Found") return } - reqCtx := s.buildRequestContext(ctx, params) - reqCtx.Session.AdvanceFlash() + // Get compiled handler + s.handlersMu.RLock() + handler := s.handlers[handlerID] + s.handlersMu.RUnlock() - s.stateMu.Lock() - defer s.stateMu.Unlock() - - s.setupRequestEnvironment(reqCtx) - - if !s.getFunction(handlerRef) { + if handler == nil { ctx.SetStatusCode(500) ctx.SetBodyString("Handler not found") return } - if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil { - ctx.SetStatusCode(500) - ctx.SetBodyString("Failed to create request object") - return - } + // Lock state for execution + s.stateMu.Lock() + defer s.stateMu.Unlock() - if err := s.state.Call(1, 1); err != nil { - ctx.SetStatusCode(500) - ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err)) - return + // Setup request context + reqCtx := &RequestContext{ + ctx: ctx, + params: params, + session: s.sessions.GetSessionFromRequest(ctx), } + reqCtx.session.AdvanceFlash() var responseBody string - if s.state.GetTop() > 0 && !s.state.IsNil(-1) { - responseBody = s.state.ToString(-1) + + if handler.isFunction { + // Function handler - use traditional approach + s.setupFunctionEnvironment(s.state, reqCtx) + + if !s.getFunction(s.state, handler.funcRef) { + ctx.SetStatusCode(500) + ctx.SetBodyString("Function handler not found") + return + } + + // Push request object + if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil { + ctx.SetStatusCode(500) + ctx.SetBodyString("Failed to create request object") + return + } + + if err := s.state.Call(1, 1); err != nil { + ctx.SetStatusCode(500) + ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err)) + return + } + + if s.state.GetTop() > 0 && !s.state.IsNil(-1) { + responseBody = s.state.ToString(-1) + } + s.state.Pop(1) + } else { + // Bytecode handler - use fast approach + s.setupFastEnvironment(s.state, reqCtx) + + if err := s.state.LoadAndRunBytecode(handler.bytecode, handler.name); err != nil { + ctx.SetStatusCode(500) + ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err)) + return + } + + if s.state.GetTop() > 0 && !s.state.IsNil(-1) { + responseBody = s.state.ToString(-1) + } s.state.Pop(1) } - s.updateSessionFromLua(reqCtx.Session) - s.applyResponse(ctx, responseBody) - s.sessions.ApplySessionCookie(ctx, reqCtx.Session) + // Apply response + s.applyResponse(ctx, s.state, responseBody, reqCtx.session) + s.sessions.ApplySessionCookie(ctx, reqCtx.session) + + // Clean up state + s.state.SetTop(0) } -func (s *Server) setupRequestEnvironment(reqCtx *RequestContext) { - s.state.PushValue(s.requestToTable(reqCtx)) - s.state.SetGlobal("__request") +func (s *Server) setupFunctionEnvironment(state *luajit.State, reqCtx *RequestContext) { + // Set up response globals for function handlers + state.NewTable() + state.PushNumber(200) + state.SetField(-2, "status") + state.NewTable() + state.SetField(-2, "headers") + state.NewTable() + state.SetField(-2, "cookies") + state.SetGlobal("__response") - s.state.PushValue(s.sessionToTable(reqCtx.Session)) - s.state.SetGlobal("__session") + // Session data + if !reqCtx.session.IsEmpty() { + state.PushValue(reqCtx.session.GetAll()) + state.SetGlobal("__session") + } +} - s.state.NewTable() - s.state.SetGlobal("__response") +func (s *Server) setupFastEnvironment(state *luajit.State, reqCtx *RequestContext) { + // Request basics as globals for fast access + state.PushString(string(reqCtx.ctx.Method())) + state.SetGlobal("REQUEST_METHOD") + + state.PushString(string(reqCtx.ctx.Path())) + state.SetGlobal("REQUEST_PATH") + + // Parameters + if reqCtx.params != nil && len(reqCtx.params.Keys) > 0 { + paramMap := make(map[string]string, len(reqCtx.params.Keys)) + for i, key := range reqCtx.params.Keys { + if i < len(reqCtx.params.Values) { + paramMap[key] = reqCtx.params.Values[i] + } + } + state.PushValue(paramMap) + state.SetGlobal("PARAMS") + } + + // Query parameters + queryMap := make(map[string]string) + reqCtx.ctx.QueryArgs().VisitAll(func(key, value []byte) { + queryMap[string(key)] = string(value) + }) + if len(queryMap) > 0 { + state.PushValue(queryMap) + state.SetGlobal("QUERY") + } + + // Headers + headerMap := make(map[string]string) + reqCtx.ctx.Request.Header.VisitAll(func(key, value []byte) { + headerMap[string(key)] = string(value) + }) + state.PushValue(headerMap) + state.SetGlobal("HEADERS") + + // Cookies + cookieMap := make(map[string]string) + reqCtx.ctx.Request.Header.VisitAllCookie(func(key, value []byte) { + cookieMap[string(key)] = string(value) + }) + if len(cookieMap) > 0 { + state.PushValue(cookieMap) + state.SetGlobal("COOKIES") + } + + // Form data + if reqCtx.ctx.IsPost() || reqCtx.ctx.IsPut() || reqCtx.ctx.IsPatch() { + form := s.parseForm(reqCtx.ctx) + if len(form) > 0 { + state.PushValue(form) + state.SetGlobal("FORM") + } + } + + // Session data + if !reqCtx.session.IsEmpty() { + state.PushValue(reqCtx.session.GetAll()) + state.SetGlobal("session_data") + } + + // CSRF token + if csrfToken := s.generateCSRFToken(); csrfToken != "" { + state.PushString(csrfToken) + state.SetGlobal("CSRF_TOKEN") + } + + // JSON encode fallback + state.RegisterGoFunction("json_encode_fallback", func(state *luajit.State) int { + val, _ := state.ToValue(1) + if b, err := json.Marshal(val); err == nil { + state.PushString(string(b)) + } else { + state.PushString("null") + } + return 1 + }) } func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any { - return map[string]any{ - "method": reqCtx.Method, - "path": reqCtx.Path, - "headers": reqCtx.Headers, - "query": reqCtx.Query, - "form": reqCtx.Form, - "cookies": reqCtx.Cookies, - "body": reqCtx.Body, - "params": reqCtx.Params, + req := map[string]any{ + "method": string(reqCtx.ctx.Method()), + "path": string(reqCtx.ctx.Path()), + "headers": make(map[string]string), + "query": make(map[string]string), + "cookies": make(map[string]string), + "body": string(reqCtx.ctx.PostBody()), } + + // Headers + headers := req["headers"].(map[string]string) + reqCtx.ctx.Request.Header.VisitAll(func(key, value []byte) { + headers[string(key)] = string(value) + }) + + // Cookies + cookies := req["cookies"].(map[string]string) + reqCtx.ctx.Request.Header.VisitAllCookie(func(key, value []byte) { + cookies[string(key)] = string(value) + }) + + // Query + query := req["query"].(map[string]string) + reqCtx.ctx.QueryArgs().VisitAll(func(key, value []byte) { + query[string(key)] = string(value) + }) + + // Params + if reqCtx.params != nil && len(reqCtx.params.Keys) > 0 { + params := make(map[string]string, len(reqCtx.params.Keys)) + for i, key := range reqCtx.params.Keys { + if i < len(reqCtx.params.Values) { + params[key] = reqCtx.params.Values[i] + } + } + req["params"] = params + } + + // Form + if reqCtx.ctx.IsPost() || reqCtx.ctx.IsPut() || reqCtx.ctx.IsPatch() { + req["form"] = s.parseForm(reqCtx.ctx) + } + + return req } -func (s *Server) sessionToTable(session *sessions.Session) map[string]any { - return map[string]any{ - "id": session.ID, - "data": session.GetAll(), - } -} - -func (s *Server) updateSessionFromLua(session *sessions.Session) { - s.state.GetGlobal("__session") - if s.state.IsNil(-1) { - s.state.Pop(1) - return - } - - s.state.GetField(-1, "data") - if s.state.IsTable(-1) { - if data, err := s.state.ToTable(-1); err == nil { +func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, state *luajit.State, body string, session *sessions.Session) { + // Update session from Lua + state.GetGlobal("session_data") + if state.IsTable(-1) { + if data, err := state.ToTable(-1); err == nil { if dataMap, ok := data.(map[string]any); ok { session.Clear() for k, v := range dataMap { @@ -266,55 +446,59 @@ func (s *Server) updateSessionFromLua(session *sessions.Session) { } } } - s.state.Pop(2) -} + state.Pop(1) -func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, body string) { - s.state.GetGlobal("__response") - if s.state.IsNil(-1) { - s.state.Pop(1) - if body != "" { - ctx.SetBodyString(body) + // Check for response table (function handlers) or response global (fast handlers) + state.GetGlobal("__response") + if state.IsNil(-1) { + state.Pop(1) + state.GetGlobal("response") + if state.IsNil(-1) { + state.Pop(1) + if body != "" { + ctx.SetBodyString(body) + } + return } - return } - s.state.GetField(-1, "status") - if s.state.IsNumber(-1) { - ctx.SetStatusCode(int(s.state.ToNumber(-1))) + // Status + if status := state.GetFieldNumber(-1, "status", 200); status != 200 { + ctx.SetStatusCode(int(status)) } - s.state.Pop(1) - s.state.GetField(-1, "headers") - if s.state.IsTable(-1) { - s.state.ForEachTableKV(-1, func(key, value string) bool { + // Headers + state.GetField(-1, "headers") + if state.IsTable(-1) { + state.ForEachTableKV(-1, func(key, value string) bool { ctx.Response.Header.Set(key, value) return true }) } - s.state.Pop(1) + state.Pop(1) - s.state.GetField(-1, "cookies") - if s.state.IsTable(-1) { - s.applyCookies(ctx) + // Cookies + state.GetField(-1, "cookies") + if state.IsTable(-1) { + s.applyCookies(ctx, state) } - s.state.Pop(1) + state.Pop(1) - s.state.Pop(1) + state.Pop(1) if body != "" { ctx.SetBodyString(body) } } -func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) { - s.state.ForEachArray(-1, func(i int, state *luajit.State) bool { - if !state.IsTable(-1) { +func (s *Server) applyCookies(ctx *fasthttp.RequestCtx, state *luajit.State) { + state.ForEachArray(-1, func(i int, st *luajit.State) bool { + if !st.IsTable(-1) { return true } - name := state.GetFieldString(-1, "name", "") - value := state.GetFieldString(-1, "value", "") + name := st.GetFieldString(-1, "name", "") + value := st.GetFieldString(-1, "value", "") if name == "" { return true } @@ -324,22 +508,29 @@ func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) { cookie.SetKey(name) cookie.SetValue(value) - cookie.SetPath(state.GetFieldString(-1, "path", "/")) - if domain := state.GetFieldString(-1, "domain", ""); domain != "" { - cookie.SetDomain(domain) - } - - if state.GetFieldBool(-1, "secure", false) { - cookie.SetSecure(true) - } - - if state.GetFieldBool(-1, "http_only", true) { - cookie.SetHTTPOnly(true) - } - - if maxAge := state.GetFieldNumber(-1, "max_age", 0); maxAge > 0 { - cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second)) + if options, ok := st.GetFieldTable(-1, "options"); ok { + if optMap, ok := options.(map[string]any); ok { + if path, ok := optMap["path"].(string); ok { + cookie.SetPath(path) + } else { + cookie.SetPath("/") + } + if domain, ok := optMap["domain"].(string); ok { + cookie.SetDomain(domain) + } + if secure, ok := optMap["secure"].(bool); ok && secure { + cookie.SetSecure(true) + } + if httpOnly, ok := optMap["http_only"].(bool); ok { + cookie.SetHTTPOnly(httpOnly) + } else { + cookie.SetHTTPOnly(true) + } + if maxAge, ok := optMap["max_age"].(int); ok && maxAge > 0 { + cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second)) + } + } } ctx.Response.Header.SetCookie(cookie) @@ -347,42 +538,6 @@ func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) { }) } -func (s *Server) buildRequestContext(ctx *fasthttp.RequestCtx, params *router.Params) *RequestContext { - reqCtx := &RequestContext{ - Method: string(ctx.Method()), - Path: string(ctx.Path()), - Headers: make(map[string]string), - Query: make(map[string]string), - Cookies: make(map[string]string), - Body: string(ctx.PostBody()), - Params: make(map[string]string), - } - - ctx.Request.Header.VisitAll(func(key, value []byte) { - reqCtx.Headers[string(key)] = string(value) - }) - - ctx.Request.Header.VisitAllCookie(func(key, value []byte) { - reqCtx.Cookies[string(key)] = string(value) - }) - - ctx.QueryArgs().VisitAll(func(key, value []byte) { - reqCtx.Query[string(key)] = string(value) - }) - - // Convert router params to map - for i, key := range params.Keys { - if i < len(params.Values) { - reqCtx.Params[key] = params.Values[i] - } - } - - reqCtx.Form = s.parseForm(ctx) - reqCtx.Session = s.sessions.GetSessionFromRequest(ctx) - - return reqCtx -} - func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any { contentType := string(ctx.Request.Header.ContentType()) form := make(map[string]any) @@ -440,317 +595,3 @@ func (s *Server) generateCSRFToken() string { rand.Read(bytes) return base64.URLEncoding.EncodeToString(bytes) } - -// Lua function implementations -func httpSetStatus(state *luajit.State) int { - code, _ := state.SafeToNumber(1) - state.GetGlobal("__response") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.SetGlobal("__response") - state.GetGlobal("__response") - } - state.PushNumber(code) - state.SetField(-2, "status") - state.Pop(1) - return 0 -} - -func httpSetHeader(state *luajit.State) int { - name, _ := state.SafeToString(1) - value, _ := state.SafeToString(2) - - state.GetGlobal("__response") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.SetGlobal("__response") - state.GetGlobal("__response") - } - - state.GetField(-1, "headers") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.PushCopy(-1) - state.SetField(-3, "headers") - } - - state.PushString(value) - state.SetField(-2, name) - state.Pop(2) - return 0 -} - -func httpRedirect(state *luajit.State) int { - url, _ := state.SafeToString(1) - status := 302.0 - if state.GetTop() >= 2 { - status, _ = state.SafeToNumber(2) - } - - state.GetGlobal("__response") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.SetGlobal("__response") - state.GetGlobal("__response") - } - - state.PushNumber(status) - state.SetField(-2, "status") - - state.GetField(-1, "headers") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.PushCopy(-1) - state.SetField(-3, "headers") - } - - state.PushString(url) - state.SetField(-2, "Location") - state.Pop(2) - return 0 -} - -func (s *Server) sessionGet(state *luajit.State) int { - key, _ := state.SafeToString(1) - - state.GetGlobal("__session") - if state.IsNil(-1) { - state.Pop(1) - state.PushNil() - return 1 - } - - state.GetField(-1, "data") - if state.IsNil(-1) { - state.Pop(2) - state.PushNil() - return 1 - } - - state.GetField(-1, key) - return 1 -} - -func (s *Server) sessionSet(state *luajit.State) int { - key, _ := state.SafeToString(1) - - state.GetGlobal("__session") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.SetGlobal("__session") - state.GetGlobal("__session") - } - - state.GetField(-1, "data") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.PushCopy(-1) - state.SetField(-3, "data") - } - - value, err := state.ToValue(2) - if err == nil { - state.PushValue(value) - state.SetField(-2, key) - } - state.Pop(2) - return 0 -} - -func (s *Server) sessionFlash(state *luajit.State) int { - key, _ := state.SafeToString(1) - - state.GetGlobal("__session") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.SetGlobal("__session") - state.GetGlobal("__session") - } - - state.GetField(-1, "flash") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.PushCopy(-1) - state.SetField(-3, "flash") - } - - value, err := state.ToValue(2) - if err == nil { - state.PushValue(value) - state.SetField(-2, key) - } - state.Pop(2) - return 0 -} - -func (s *Server) sessionGetFlash(state *luajit.State) int { - key, _ := state.SafeToString(1) - - state.GetGlobal("__session") - if state.IsNil(-1) { - state.Pop(1) - state.PushNil() - return 1 - } - - state.GetField(-1, "flash") - if state.IsNil(-1) { - state.Pop(2) - state.PushNil() - return 1 - } - - state.GetField(-1, key) - return 1 -} - -func cookieSet(state *luajit.State) int { - name, _ := state.SafeToString(1) - value, _ := state.SafeToString(2) - - maxAge := 0 - path := "/" - domain := "" - secure := false - httpOnly := true - - if state.GetTop() >= 3 && state.IsTable(3) { - maxAge = int(state.GetFieldNumber(3, "max_age", 0)) - path = state.GetFieldString(3, "path", "/") - domain = state.GetFieldString(3, "domain", "") - secure = state.GetFieldBool(3, "secure", false) - httpOnly = state.GetFieldBool(3, "http_only", true) - } - - state.GetGlobal("__response") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.SetGlobal("__response") - state.GetGlobal("__response") - } - - state.GetField(-1, "cookies") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.PushCopy(-1) - state.SetField(-3, "cookies") - } - - cookieData := map[string]any{ - "name": name, - "value": value, - "path": path, - "secure": secure, - "http_only": httpOnly, - } - if domain != "" { - cookieData["domain"] = domain - } - if maxAge > 0 { - cookieData["max_age"] = maxAge - } - - state.PushValue(cookieData) - - length := globalServer.getTableLength(-2) - state.PushNumber(float64(length + 1)) - state.PushCopy(-2) - state.SetTable(-4) - - state.Pop(3) - return 0 -} - -func cookieGet(state *luajit.State) int { - name, _ := state.SafeToString(1) - - state.GetGlobal("__request") - if state.IsNil(-1) { - state.Pop(1) - state.PushNil() - return 1 - } - - state.GetField(-1, "cookies") - if state.IsNil(-1) { - state.Pop(2) - state.PushNil() - return 1 - } - - state.GetField(-1, name) - return 1 -} - -func (s *Server) csrfGenerate(state *luajit.State) int { - token := s.generateCSRFToken() - - state.GetGlobal("__session") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.SetGlobal("__session") - state.GetGlobal("__session") - } - - state.GetField(-1, "data") - if state.IsNil(-1) { - state.Pop(1) - state.NewTable() - state.PushCopy(-1) - state.SetField(-3, "data") - } - - state.PushString(token) - state.SetField(-2, "_csrf_token") - state.Pop(2) - - state.PushString(token) - return 1 -} - -func (s *Server) csrfValidate(state *luajit.State) int { - state.GetGlobal("__session") - if state.IsNil(-1) { - state.Pop(1) - state.PushBoolean(false) - return 1 - } - - sessionToken := state.GetFieldString(-1, "data._csrf_token", "") - state.Pop(1) - - state.GetGlobal("__request") - if state.IsNil(-1) { - state.Pop(1) - state.PushBoolean(false) - return 1 - } - - requestToken := state.GetFieldString(-1, "form._csrf_token", "") - state.Pop(1) - - state.PushBoolean(sessionToken != "" && sessionToken == requestToken) - return 1 -} - -func (s *Server) getTableLength(index int) int { - length := 0 - s.state.PushNil() - for s.state.Next(index - 1) { - length++ - s.state.Pop(1) - } - return length -} diff --git a/http/http.lua b/http/http.lua index d28cbe5..fef7630 100644 --- a/http/http.lua +++ b/http/http.lua @@ -1,3 +1,7 @@ +-- Fast response handling +local response = {status = 200, headers = {}, cookies = {}} +local session_data = {} + http = {} function http.listen(port) @@ -8,20 +12,15 @@ function http.route(method, path, handler) return __http_route(method, path, handler) end -function http.status(code) - return __http_set_status(code) +function http.status(code) + response.status = code end -function http.header(name, value) - return __http_set_header(name, value) +function http.header(k, v) + response.headers[k] = v end -function http.redirect(url, status) - __http_redirect(url, status or 302) - coroutine.yield() -- Exit handler -end - -function http.json(data) +function http.json(data) http.header("Content-Type", "application/json") return json.encode(data) end @@ -32,84 +31,119 @@ function http.html(content) end function http.text(content) - http.header("Content-Type", "text/plain") + http.header("Content-Type", "text/plain") return content end +function http.redirect(url, code) + response.status = code or 302 + response.headers["Location"] = url + coroutine.yield() +end + -- Session functions session = {} -function session.get(key) - return __session_get(key) +function session.get(key) + return session_data[key] end -function session.set(key, value) - return __session_set(key, value) +function session.set(key, val) + session_data[key] = val end -function session.flash(key, value) - return __session_flash(key, value) +function session.flash(key, val) + session_data["_flash_" .. key] = val end -function session.get_flash(key) - return __session_get_flash(key) +function session.get_flash(key) + local val = session_data["_flash_" .. key] + session_data["_flash_" .. key] = nil + return val end --- Cookie functions +-- Cookie functions cookie = {} -function cookie.set(name, value, options) - return __cookie_set(name, value, options) +function cookie.get(name) + return COOKIES and COOKIES[name] end -function cookie.get(name) - return __cookie_get(name) +function cookie.set(name, value, options) + response.cookies[#response.cookies + 1] = { + name = name, + value = value, + options = options or {} + } end -- CSRF functions csrf = {} function csrf.generate() - return __csrf_generate() + local token = CSRF_TOKEN or "" + session.set("_csrf_token", token) + return token end function csrf.validate() - return __csrf_validate() + local session_token = session.get("_csrf_token") + local form_token = FORM and FORM._csrf_token + return session_token and session_token == form_token end function csrf.field() - local token = csrf.generate() - return string.format('', token) + return '' end +-- Fast JSON encoding +json = { + encode = function(data) + if type(data) == "string" then + return '"' .. data .. '"' + elseif type(data) == "number" then + return tostring(data) + elseif type(data) == "boolean" then + return data and "true" or "false" + elseif data == nil then + return "null" + elseif type(data) == "table" then + -- Check if it's an array + local isArray = true + local n = 0 + for k, v in pairs(data) do + n = n + 1 + if type(k) ~= "number" or k ~= n then + isArray = false + break + end + end + + if isArray then + local result = "[" + for i = 1, n do + if i > 1 then result = result .. "," end + result = result .. json.encode(data[i]) + end + return result .. "]" + else + local result = "{" + local first = true + for k, v in pairs(data) do + if not first then result = result .. "," end + result = result .. '"' .. tostring(k) .. '":' .. json.encode(v) + first = false + end + return result .. "}" + end + else + return json_encode_fallback(data) + end + end +} + -- Helper functions function redirect_with_flash(url, type, message) session.flash(type, message) http.redirect(url) -end - --- JSON encoding/decoding placeholder -json = { - encode = function(data) - -- Simplified JSON encoding - if type(data) == "table" then - local result = "{" - local first = true - for k, v in pairs(data) do - if not first then result = result .. "," end - result = result .. '"' .. tostring(k) .. '":' .. json.encode(v) - first = false - end - return result .. "}" - elseif type(data) == "string" then - return '"' .. data .. '"' - else - return tostring(data) - end - end, - - decode = function(str) - -- Simplified JSON decoding - you'd want a proper implementation - return {} - end -} \ No newline at end of file +end \ No newline at end of file diff --git a/http/router/router.go b/http/router/router.go index c2f2354..817b7c5 100644 --- a/http/router/router.go +++ b/http/router/router.go @@ -1,25 +1,24 @@ package router import ( - "errors" - "strings" - "sync" + "fmt" ) -// node represents a node in the radix trie +// Handler function that takes parameters as strings +type Handler func(params []string) + type node struct { segment string - handler int // Lua function reference + handlerID int children []*node - isDynamic bool // :param - isWildcard bool // *param - paramName string + isDynamic bool + isWildcard bool + maxParams uint8 } -// Router is a string-based HTTP router with efficient lookup type Router struct { get, post, put, patch, delete *node - mu sync.RWMutex + paramsBuffer []string } // Params holds URL parameters @@ -41,15 +40,15 @@ func (p *Params) Get(name string) string { // New creates a new Router instance func New() *Router { return &Router{ - get: &node{}, - post: &node{}, - put: &node{}, - patch: &node{}, - delete: &node{}, + get: &node{}, + post: &node{}, + put: &node{}, + patch: &node{}, + delete: &node{}, + paramsBuffer: make([]string, 64), } } -// methodNode returns the root node for a method func (r *Router) methodNode(method string) *node { switch method { case "GET": @@ -67,47 +66,71 @@ func (r *Router) methodNode(method string) *node { } } -// AddRoute adds a new route with handler reference -func (r *Router) AddRoute(method, path string, handlerRef int) error { - r.mu.Lock() - defer r.mu.Unlock() - +// AddRoute adds a route with handler ID (for compatibility) +func (r *Router) AddRoute(method, path string, handlerID int) error { root := r.methodNode(method) if root == nil { - return errors.New("unsupported HTTP method") + return fmt.Errorf("unsupported method: %s", method) } + // Create a handler that stores the ID + h := func(params []string) { + // This is a placeholder - the actual execution happens in HTTP module + } + + return r.addRoute(root, path, h, handlerID) +} + +// readSegment extracts the next path segment +func readSegment(path string, start int) (segment string, end int, hasMore bool) { + if start >= len(path) { + return "", start, false + } + if path[start] == '/' { + start++ + } + if start >= len(path) { + return "", start, false + } + end = start + for end < len(path) && path[end] != '/' { + end++ + } + return path[start:end], end, end < len(path) +} + +// addRoute adds a new route to the trie +func (r *Router) addRoute(root *node, path string, h Handler, handlerID int) error { if path == "/" { - root.handler = handlerRef + root.handlerID = handlerID return nil } - return r.addRoute(root, path, handlerRef) -} - -// addRoute adds a route to the trie -func (r *Router) addRoute(root *node, path string, handlerRef int) error { - segments := r.parseSegments(path) current := root + pos := 0 + lastWC := false + count := uint8(0) - for _, seg := range segments { - isDyn := strings.HasPrefix(seg, ":") - isWC := strings.HasPrefix(seg, "*") - - if isWC && seg != segments[len(segments)-1] { - return errors.New("wildcard must be the last segment") + for { + seg, newPos, more := readSegment(path, pos) + if seg == "" { + break } - paramName := "" - if isDyn { - paramName = seg[1:] - seg = ":" - } else if isWC { - paramName = seg[1:] - seg = "*" + isDyn := len(seg) > 1 && seg[0] == ':' + isWC := len(seg) > 0 && seg[0] == '*' + + if isWC { + if lastWC || more { + return fmt.Errorf("wildcard must be the last segment in the path") + } + lastWC = true + } + + if isDyn || isWC { + count++ } - // Find or create child var child *node for _, c := range current.children { if c.segment == seg { @@ -117,144 +140,102 @@ func (r *Router) addRoute(root *node, path string, handlerRef int) error { } if child == nil { - child = &node{ - segment: seg, - isDynamic: isDyn, - isWildcard: isWC, - paramName: paramName, - } + child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC} current.children = append(current.children, child) } + if child.maxParams < count { + child.maxParams = count + } + current = child + pos = newPos } - current.handler = handlerRef + current.handlerID = handlerID return nil } -// parseSegments splits path into segments -func (r *Router) parseSegments(path string) []string { - segments := strings.Split(strings.Trim(path, "/"), "/") - var result []string - for _, seg := range segments { - if seg != "" { - result = append(result, seg) - } - } - return result -} - -// Lookup finds handler and parameters for a method and path +// Lookup finds a handler matching method and path func (r *Router) Lookup(method, path string) (int, *Params, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - root := r.methodNode(method) if root == nil { return 0, nil, false } if path == "/" { - if root.handler != 0 { - return root.handler, &Params{}, true + if root.handlerID != 0 { + return root.handlerID, &Params{}, true } return 0, nil, false } - segments := r.parseSegments(path) - handler, params := r.match(root, segments, 0) - if handler == 0 { + buffer := r.paramsBuffer + if cap(buffer) < int(root.maxParams) { + buffer = make([]string, root.maxParams) + r.paramsBuffer = buffer + } + buffer = buffer[:0] + + handlerID, paramCount, paramKeys, found := r.match(root, path, 0, &buffer) + if !found { return 0, nil, false } - return handler, params, true -} - -// match traverses the trie to find handler -func (r *Router) match(current *node, segments []string, index int) (int, *Params) { - if index >= len(segments) { - if current.handler != 0 { - return current.handler, &Params{} - } - return 0, nil + params := &Params{ + Keys: paramKeys, + Values: buffer[:paramCount], } - segment := segments[index] + return handlerID, params, true +} - // Check exact match first - for _, child := range current.children { - if child.segment == segment { - handler, params := r.match(child, segments, index+1) - if handler != 0 { - return handler, params +// match traverses the trie to find a handler +func (r *Router) match(current *node, path string, start int, params *[]string) (int, int, []string, bool) { + paramCount := 0 + var paramKeys []string + + // Check wildcards first + for _, c := range current.children { + if c.isWildcard { + rem := path[start:] + if len(rem) > 0 && rem[0] == '/' { + rem = rem[1:] + } + *params = append(*params, rem) + // Extract param name from *name format + paramName := c.segment[1:] // Remove the * prefix + paramKeys = append(paramKeys, paramName) + return c.handlerID, 1, paramKeys, c.handlerID != 0 + } + } + + seg, pos, more := readSegment(path, start) + if seg == "" { + return current.handlerID, 0, paramKeys, current.handlerID != 0 + } + + for _, c := range current.children { + if c.segment == seg || c.isDynamic { + if c.isDynamic { + *params = append(*params, seg) + // Extract param name from :name format + paramName := c.segment[1:] // Remove the : prefix + paramKeys = append(paramKeys, paramName) + paramCount++ + } + + if !more { + return c.handlerID, paramCount, paramKeys, c.handlerID != 0 + } + + handlerID, nestedCount, nestedKeys, ok := r.match(c, path, pos, params) + if ok { + allKeys := append(paramKeys, nestedKeys...) + return handlerID, paramCount + nestedCount, allKeys, true } } } - // Check dynamic match second - for _, child := range current.children { - if child.isDynamic { - handler, params := r.match(child, segments, index+1) - if handler != 0 { - // Prepend this parameter - newParams := &Params{ - Keys: append([]string{child.paramName}, params.Keys...), - Values: append([]string{segment}, params.Values...), - } - return handler, newParams - } - } - } - - // Check wildcard last (catches everything remaining) - for _, child := range current.children { - if child.isWildcard { - remaining := strings.Join(segments[index:], "/") - return child.handler, &Params{ - Keys: []string{child.paramName}, - Values: []string{remaining}, - } - } - } - - return 0, nil -} - -// RemoveRoute removes a route -func (r *Router) RemoveRoute(method, path string) { - r.mu.Lock() - defer r.mu.Unlock() - - root := r.methodNode(method) - if root == nil { - return - } - - if path == "/" { - root.handler = 0 - return - } - - segments := r.parseSegments(path) - r.removeRoute(root, segments, 0) -} - -// removeRoute removes a route from the trie -func (r *Router) removeRoute(current *node, segments []string, index int) { - if index >= len(segments) { - current.handler = 0 - return - } - - segment := segments[index] - - for _, child := range current.children { - if child.segment == segment || - (child.isDynamic && strings.HasPrefix(segment, ":")) || - (child.isWildcard && strings.HasPrefix(segment, "*")) { - r.removeRoute(child, segments, index+1) - break - } - } + return 0, 0, nil, false }