next pass

This commit is contained in:
Sky Johnson 2025-07-14 17:36:59 -05:00
parent bb06e2431d
commit 843e318e01
3 changed files with 565 additions and 709 deletions

View File

@ -21,25 +21,33 @@ import (
//go:embed http.lua //go:embed http.lua
var httpLuaCode string var httpLuaCode string
// HandlerFunc represents a Lua handler
type HandlerFunc struct {
bytecode []byte
funcRef int
name string
isFunction bool
}
// Server with single state for function handler compatibility
type Server struct { type Server struct {
server *fasthttp.Server server *fasthttp.Server
router *router.Router router *router.Router
sessions *sessions.SessionManager sessions *sessions.SessionManager
state *luajit.State state *luajit.State
stateMu sync.Mutex stateMu sync.Mutex
handlers map[int]*HandlerFunc
handlersMu sync.RWMutex
funcCounter int funcCounter int
} }
// RequestContext with lazy parsing
type RequestContext struct { type RequestContext struct {
Method string ctx *fasthttp.RequestCtx
Path string params *router.Params
Headers map[string]string session *sessions.Session
Query map[string]string parsedForm map[string]any
Form map[string]any formOnce sync.Once
Cookies map[string]string
Session *sessions.Session
Body string
Params map[string]string
} }
var globalServer *Server var globalServer *Server
@ -49,6 +57,7 @@ func NewServer(state *luajit.State) *Server {
router: router.New(), router: router.New(),
sessions: sessions.NewSessionManager(10000), sessions: sessions.NewSessionManager(10000),
state: state, state: state,
handlers: make(map[int]*HandlerFunc),
} }
} }
@ -56,19 +65,8 @@ func RegisterHTTPFunctions(L *luajit.State) error {
globalServer = NewServer(L) globalServer = NewServer(L)
functions := map[string]luajit.GoFunction{ functions := map[string]luajit.GoFunction{
"__http_listen": globalServer.httpListen, "__http_listen": globalServer.httpListen,
"__http_route": globalServer.httpRoute, "__http_route": globalServer.httpRoute,
"__http_set_status": httpSetStatus,
"__http_set_header": httpSetHeader,
"__http_redirect": httpRedirect,
"__session_get": globalServer.sessionGet,
"__session_set": globalServer.sessionSet,
"__session_flash": globalServer.sessionFlash,
"__session_get_flash": globalServer.sessionGetFlash,
"__cookie_set": cookieSet,
"__cookie_get": cookieGet,
"__csrf_generate": globalServer.csrfGenerate,
"__csrf_validate": globalServer.csrfValidate,
} }
for name, fn := range functions { for name, fn := range functions {
@ -87,8 +85,13 @@ func (s *Server) httpListen(state *luajit.State) int {
} }
s.server = &fasthttp.Server{ s.server = &fasthttp.Server{
Handler: s.requestHandler, Handler: s.fastRequestHandler,
Name: "Moonshark/1.0", Name: "Moonshark/2.0",
Concurrency: 256 * 1024,
ReadBufferSize: 4096,
WriteBufferSize: 4096,
ReduceMemoryUsage: true,
NoDefaultServerHeader: true,
} }
addr := fmt.Sprintf(":%d", int(port)) addr := fmt.Sprintf(":%d", int(port))
@ -114,16 +117,44 @@ func (s *Server) httpRoute(state *luajit.State) int {
return state.PushError("route: path must be string") return state.PushError("route: path must be string")
} }
if !state.IsFunction(3) { s.funcCounter++
return state.PushError("route: handler must be function") handlerID := s.funcCounter
if state.IsFunction(3) {
// Function handler - store reference
state.PushCopy(3)
funcRef := s.storeFunction(state)
s.handlersMu.Lock()
s.handlers[handlerID] = &HandlerFunc{
funcRef: funcRef,
name: fmt.Sprintf("%s %s", method, path),
isFunction: true,
}
s.handlersMu.Unlock()
} else {
// String handler - compile to bytecode
handlerCode, err := state.SafeToString(3)
if err != nil {
return state.PushError("route: handler must be function or string")
}
bytecode, err := state.CompileBytecode(handlerCode, fmt.Sprintf("handler_%s_%s", method, path))
if err != nil {
return state.PushError("route: failed to compile handler: %s", err.Error())
}
s.handlersMu.Lock()
s.handlers[handlerID] = &HandlerFunc{
bytecode: bytecode,
name: fmt.Sprintf("%s %s", method, path),
isFunction: false,
}
s.handlersMu.Unlock()
} }
// Store function and get reference
state.PushCopy(3)
funcRef := s.storeFunction()
// Add route to router // Add route to router
if err := s.router.AddRoute(strings.ToUpper(method), path, funcRef); err != nil { if err := s.router.AddRoute(strings.ToUpper(method), path, handlerID); err != nil {
return state.PushError("route: failed to add route: %s", err.Error()) return state.PushError("route: failed to add route: %s", err.Error())
} }
@ -131,133 +162,282 @@ func (s *Server) httpRoute(state *luajit.State) int {
return 1 return 1
} }
func (s *Server) storeFunction() int { func (s *Server) storeFunction(state *luajit.State) int {
s.state.GetGlobal("__moonshark_functions") state.GetGlobal("__moonshark_functions")
if s.state.IsNil(-1) { if state.IsNil(-1) {
s.state.Pop(1) state.Pop(1)
s.state.NewTable() state.NewTable()
s.state.PushCopy(-1) state.PushCopy(-1)
s.state.SetGlobal("__moonshark_functions") state.SetGlobal("__moonshark_functions")
} }
s.funcCounter++ s.funcCounter++
s.state.PushNumber(float64(s.funcCounter)) state.PushNumber(float64(s.funcCounter))
s.state.PushCopy(-3) state.PushCopy(-3)
s.state.SetTable(-3) state.SetTable(-3)
s.state.Pop(2) state.Pop(2)
return s.funcCounter return s.funcCounter
} }
func (s *Server) getFunction(ref int) bool { func (s *Server) getFunction(state *luajit.State, ref int) bool {
s.state.GetGlobal("__moonshark_functions") state.GetGlobal("__moonshark_functions")
if s.state.IsNil(-1) { if state.IsNil(-1) {
s.state.Pop(1) state.Pop(1)
return false return false
} }
s.state.PushNumber(float64(ref)) state.PushNumber(float64(ref))
s.state.GetTable(-2) state.GetTable(-2)
isFunc := s.state.IsFunction(-1) isFunc := state.IsFunction(-1)
if !isFunc { if !isFunc {
s.state.Pop(2) state.Pop(2)
return false return false
} }
s.state.Remove(-2) state.Remove(-2)
return true return true
} }
func (s *Server) requestHandler(ctx *fasthttp.RequestCtx) { func (s *Server) fastRequestHandler(ctx *fasthttp.RequestCtx) {
method := string(ctx.Method()) method := string(ctx.Method())
path := string(ctx.Path()) path := string(ctx.Path())
// Look up route in router // Fast route lookup
handlerRef, params, found := s.router.Lookup(method, path) handlerID, params, found := s.router.Lookup(method, path)
if !found { if !found {
ctx.SetStatusCode(404) ctx.SetStatusCode(404)
ctx.SetBodyString("Not Found") ctx.SetBodyString("Not Found")
return return
} }
reqCtx := s.buildRequestContext(ctx, params) // Get compiled handler
reqCtx.Session.AdvanceFlash() s.handlersMu.RLock()
handler := s.handlers[handlerID]
s.handlersMu.RUnlock()
s.stateMu.Lock() if handler == nil {
defer s.stateMu.Unlock()
s.setupRequestEnvironment(reqCtx)
if !s.getFunction(handlerRef) {
ctx.SetStatusCode(500) ctx.SetStatusCode(500)
ctx.SetBodyString("Handler not found") ctx.SetBodyString("Handler not found")
return return
} }
if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil { // Lock state for execution
ctx.SetStatusCode(500) s.stateMu.Lock()
ctx.SetBodyString("Failed to create request object") defer s.stateMu.Unlock()
return
}
if err := s.state.Call(1, 1); err != nil { // Setup request context
ctx.SetStatusCode(500) reqCtx := &RequestContext{
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err)) ctx: ctx,
return params: params,
session: s.sessions.GetSessionFromRequest(ctx),
} }
reqCtx.session.AdvanceFlash()
var responseBody string var responseBody string
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
responseBody = s.state.ToString(-1) if handler.isFunction {
// Function handler - use traditional approach
s.setupFunctionEnvironment(s.state, reqCtx)
if !s.getFunction(s.state, handler.funcRef) {
ctx.SetStatusCode(500)
ctx.SetBodyString("Function handler not found")
return
}
// Push request object
if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil {
ctx.SetStatusCode(500)
ctx.SetBodyString("Failed to create request object")
return
}
if err := s.state.Call(1, 1); err != nil {
ctx.SetStatusCode(500)
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
return
}
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
responseBody = s.state.ToString(-1)
}
s.state.Pop(1)
} else {
// Bytecode handler - use fast approach
s.setupFastEnvironment(s.state, reqCtx)
if err := s.state.LoadAndRunBytecode(handler.bytecode, handler.name); err != nil {
ctx.SetStatusCode(500)
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
return
}
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
responseBody = s.state.ToString(-1)
}
s.state.Pop(1) s.state.Pop(1)
} }
s.updateSessionFromLua(reqCtx.Session) // Apply response
s.applyResponse(ctx, responseBody) s.applyResponse(ctx, s.state, responseBody, reqCtx.session)
s.sessions.ApplySessionCookie(ctx, reqCtx.Session) s.sessions.ApplySessionCookie(ctx, reqCtx.session)
// Clean up state
s.state.SetTop(0)
} }
func (s *Server) setupRequestEnvironment(reqCtx *RequestContext) { func (s *Server) setupFunctionEnvironment(state *luajit.State, reqCtx *RequestContext) {
s.state.PushValue(s.requestToTable(reqCtx)) // Set up response globals for function handlers
s.state.SetGlobal("__request") state.NewTable()
state.PushNumber(200)
state.SetField(-2, "status")
state.NewTable()
state.SetField(-2, "headers")
state.NewTable()
state.SetField(-2, "cookies")
state.SetGlobal("__response")
s.state.PushValue(s.sessionToTable(reqCtx.Session)) // Session data
s.state.SetGlobal("__session") if !reqCtx.session.IsEmpty() {
state.PushValue(reqCtx.session.GetAll())
state.SetGlobal("__session")
}
}
s.state.NewTable() func (s *Server) setupFastEnvironment(state *luajit.State, reqCtx *RequestContext) {
s.state.SetGlobal("__response") // Request basics as globals for fast access
state.PushString(string(reqCtx.ctx.Method()))
state.SetGlobal("REQUEST_METHOD")
state.PushString(string(reqCtx.ctx.Path()))
state.SetGlobal("REQUEST_PATH")
// Parameters
if reqCtx.params != nil && len(reqCtx.params.Keys) > 0 {
paramMap := make(map[string]string, len(reqCtx.params.Keys))
for i, key := range reqCtx.params.Keys {
if i < len(reqCtx.params.Values) {
paramMap[key] = reqCtx.params.Values[i]
}
}
state.PushValue(paramMap)
state.SetGlobal("PARAMS")
}
// Query parameters
queryMap := make(map[string]string)
reqCtx.ctx.QueryArgs().VisitAll(func(key, value []byte) {
queryMap[string(key)] = string(value)
})
if len(queryMap) > 0 {
state.PushValue(queryMap)
state.SetGlobal("QUERY")
}
// Headers
headerMap := make(map[string]string)
reqCtx.ctx.Request.Header.VisitAll(func(key, value []byte) {
headerMap[string(key)] = string(value)
})
state.PushValue(headerMap)
state.SetGlobal("HEADERS")
// Cookies
cookieMap := make(map[string]string)
reqCtx.ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
cookieMap[string(key)] = string(value)
})
if len(cookieMap) > 0 {
state.PushValue(cookieMap)
state.SetGlobal("COOKIES")
}
// Form data
if reqCtx.ctx.IsPost() || reqCtx.ctx.IsPut() || reqCtx.ctx.IsPatch() {
form := s.parseForm(reqCtx.ctx)
if len(form) > 0 {
state.PushValue(form)
state.SetGlobal("FORM")
}
}
// Session data
if !reqCtx.session.IsEmpty() {
state.PushValue(reqCtx.session.GetAll())
state.SetGlobal("session_data")
}
// CSRF token
if csrfToken := s.generateCSRFToken(); csrfToken != "" {
state.PushString(csrfToken)
state.SetGlobal("CSRF_TOKEN")
}
// JSON encode fallback
state.RegisterGoFunction("json_encode_fallback", func(state *luajit.State) int {
val, _ := state.ToValue(1)
if b, err := json.Marshal(val); err == nil {
state.PushString(string(b))
} else {
state.PushString("null")
}
return 1
})
} }
func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any { func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any {
return map[string]any{ req := map[string]any{
"method": reqCtx.Method, "method": string(reqCtx.ctx.Method()),
"path": reqCtx.Path, "path": string(reqCtx.ctx.Path()),
"headers": reqCtx.Headers, "headers": make(map[string]string),
"query": reqCtx.Query, "query": make(map[string]string),
"form": reqCtx.Form, "cookies": make(map[string]string),
"cookies": reqCtx.Cookies, "body": string(reqCtx.ctx.PostBody()),
"body": reqCtx.Body,
"params": reqCtx.Params,
} }
// Headers
headers := req["headers"].(map[string]string)
reqCtx.ctx.Request.Header.VisitAll(func(key, value []byte) {
headers[string(key)] = string(value)
})
// Cookies
cookies := req["cookies"].(map[string]string)
reqCtx.ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
cookies[string(key)] = string(value)
})
// Query
query := req["query"].(map[string]string)
reqCtx.ctx.QueryArgs().VisitAll(func(key, value []byte) {
query[string(key)] = string(value)
})
// Params
if reqCtx.params != nil && len(reqCtx.params.Keys) > 0 {
params := make(map[string]string, len(reqCtx.params.Keys))
for i, key := range reqCtx.params.Keys {
if i < len(reqCtx.params.Values) {
params[key] = reqCtx.params.Values[i]
}
}
req["params"] = params
}
// Form
if reqCtx.ctx.IsPost() || reqCtx.ctx.IsPut() || reqCtx.ctx.IsPatch() {
req["form"] = s.parseForm(reqCtx.ctx)
}
return req
} }
func (s *Server) sessionToTable(session *sessions.Session) map[string]any { func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, state *luajit.State, body string, session *sessions.Session) {
return map[string]any{ // Update session from Lua
"id": session.ID, state.GetGlobal("session_data")
"data": session.GetAll(), if state.IsTable(-1) {
} if data, err := state.ToTable(-1); err == nil {
}
func (s *Server) updateSessionFromLua(session *sessions.Session) {
s.state.GetGlobal("__session")
if s.state.IsNil(-1) {
s.state.Pop(1)
return
}
s.state.GetField(-1, "data")
if s.state.IsTable(-1) {
if data, err := s.state.ToTable(-1); err == nil {
if dataMap, ok := data.(map[string]any); ok { if dataMap, ok := data.(map[string]any); ok {
session.Clear() session.Clear()
for k, v := range dataMap { for k, v := range dataMap {
@ -266,55 +446,59 @@ func (s *Server) updateSessionFromLua(session *sessions.Session) {
} }
} }
} }
s.state.Pop(2) state.Pop(1)
}
func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, body string) { // Check for response table (function handlers) or response global (fast handlers)
s.state.GetGlobal("__response") state.GetGlobal("__response")
if s.state.IsNil(-1) { if state.IsNil(-1) {
s.state.Pop(1) state.Pop(1)
if body != "" { state.GetGlobal("response")
ctx.SetBodyString(body) if state.IsNil(-1) {
state.Pop(1)
if body != "" {
ctx.SetBodyString(body)
}
return
} }
return
} }
s.state.GetField(-1, "status") // Status
if s.state.IsNumber(-1) { if status := state.GetFieldNumber(-1, "status", 200); status != 200 {
ctx.SetStatusCode(int(s.state.ToNumber(-1))) ctx.SetStatusCode(int(status))
} }
s.state.Pop(1)
s.state.GetField(-1, "headers") // Headers
if s.state.IsTable(-1) { state.GetField(-1, "headers")
s.state.ForEachTableKV(-1, func(key, value string) bool { if state.IsTable(-1) {
state.ForEachTableKV(-1, func(key, value string) bool {
ctx.Response.Header.Set(key, value) ctx.Response.Header.Set(key, value)
return true return true
}) })
} }
s.state.Pop(1) state.Pop(1)
s.state.GetField(-1, "cookies") // Cookies
if s.state.IsTable(-1) { state.GetField(-1, "cookies")
s.applyCookies(ctx) if state.IsTable(-1) {
s.applyCookies(ctx, state)
} }
s.state.Pop(1) state.Pop(1)
s.state.Pop(1) state.Pop(1)
if body != "" { if body != "" {
ctx.SetBodyString(body) ctx.SetBodyString(body)
} }
} }
func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) { func (s *Server) applyCookies(ctx *fasthttp.RequestCtx, state *luajit.State) {
s.state.ForEachArray(-1, func(i int, state *luajit.State) bool { state.ForEachArray(-1, func(i int, st *luajit.State) bool {
if !state.IsTable(-1) { if !st.IsTable(-1) {
return true return true
} }
name := state.GetFieldString(-1, "name", "") name := st.GetFieldString(-1, "name", "")
value := state.GetFieldString(-1, "value", "") value := st.GetFieldString(-1, "value", "")
if name == "" { if name == "" {
return true return true
} }
@ -324,22 +508,29 @@ func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
cookie.SetKey(name) cookie.SetKey(name)
cookie.SetValue(value) cookie.SetValue(value)
cookie.SetPath(state.GetFieldString(-1, "path", "/"))
if domain := state.GetFieldString(-1, "domain", ""); domain != "" { if options, ok := st.GetFieldTable(-1, "options"); ok {
cookie.SetDomain(domain) if optMap, ok := options.(map[string]any); ok {
} if path, ok := optMap["path"].(string); ok {
cookie.SetPath(path)
if state.GetFieldBool(-1, "secure", false) { } else {
cookie.SetSecure(true) cookie.SetPath("/")
} }
if domain, ok := optMap["domain"].(string); ok {
if state.GetFieldBool(-1, "http_only", true) { cookie.SetDomain(domain)
cookie.SetHTTPOnly(true) }
} if secure, ok := optMap["secure"].(bool); ok && secure {
cookie.SetSecure(true)
if maxAge := state.GetFieldNumber(-1, "max_age", 0); maxAge > 0 { }
cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second)) if httpOnly, ok := optMap["http_only"].(bool); ok {
cookie.SetHTTPOnly(httpOnly)
} else {
cookie.SetHTTPOnly(true)
}
if maxAge, ok := optMap["max_age"].(int); ok && maxAge > 0 {
cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second))
}
}
} }
ctx.Response.Header.SetCookie(cookie) ctx.Response.Header.SetCookie(cookie)
@ -347,42 +538,6 @@ func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
}) })
} }
func (s *Server) buildRequestContext(ctx *fasthttp.RequestCtx, params *router.Params) *RequestContext {
reqCtx := &RequestContext{
Method: string(ctx.Method()),
Path: string(ctx.Path()),
Headers: make(map[string]string),
Query: make(map[string]string),
Cookies: make(map[string]string),
Body: string(ctx.PostBody()),
Params: make(map[string]string),
}
ctx.Request.Header.VisitAll(func(key, value []byte) {
reqCtx.Headers[string(key)] = string(value)
})
ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
reqCtx.Cookies[string(key)] = string(value)
})
ctx.QueryArgs().VisitAll(func(key, value []byte) {
reqCtx.Query[string(key)] = string(value)
})
// Convert router params to map
for i, key := range params.Keys {
if i < len(params.Values) {
reqCtx.Params[key] = params.Values[i]
}
}
reqCtx.Form = s.parseForm(ctx)
reqCtx.Session = s.sessions.GetSessionFromRequest(ctx)
return reqCtx
}
func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any { func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any {
contentType := string(ctx.Request.Header.ContentType()) contentType := string(ctx.Request.Header.ContentType())
form := make(map[string]any) form := make(map[string]any)
@ -440,317 +595,3 @@ func (s *Server) generateCSRFToken() string {
rand.Read(bytes) rand.Read(bytes)
return base64.URLEncoding.EncodeToString(bytes) return base64.URLEncoding.EncodeToString(bytes)
} }
// Lua function implementations
func httpSetStatus(state *luajit.State) int {
code, _ := state.SafeToNumber(1)
state.GetGlobal("__response")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.SetGlobal("__response")
state.GetGlobal("__response")
}
state.PushNumber(code)
state.SetField(-2, "status")
state.Pop(1)
return 0
}
func httpSetHeader(state *luajit.State) int {
name, _ := state.SafeToString(1)
value, _ := state.SafeToString(2)
state.GetGlobal("__response")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.SetGlobal("__response")
state.GetGlobal("__response")
}
state.GetField(-1, "headers")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.PushCopy(-1)
state.SetField(-3, "headers")
}
state.PushString(value)
state.SetField(-2, name)
state.Pop(2)
return 0
}
func httpRedirect(state *luajit.State) int {
url, _ := state.SafeToString(1)
status := 302.0
if state.GetTop() >= 2 {
status, _ = state.SafeToNumber(2)
}
state.GetGlobal("__response")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.SetGlobal("__response")
state.GetGlobal("__response")
}
state.PushNumber(status)
state.SetField(-2, "status")
state.GetField(-1, "headers")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.PushCopy(-1)
state.SetField(-3, "headers")
}
state.PushString(url)
state.SetField(-2, "Location")
state.Pop(2)
return 0
}
func (s *Server) sessionGet(state *luajit.State) int {
key, _ := state.SafeToString(1)
state.GetGlobal("__session")
if state.IsNil(-1) {
state.Pop(1)
state.PushNil()
return 1
}
state.GetField(-1, "data")
if state.IsNil(-1) {
state.Pop(2)
state.PushNil()
return 1
}
state.GetField(-1, key)
return 1
}
func (s *Server) sessionSet(state *luajit.State) int {
key, _ := state.SafeToString(1)
state.GetGlobal("__session")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.SetGlobal("__session")
state.GetGlobal("__session")
}
state.GetField(-1, "data")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.PushCopy(-1)
state.SetField(-3, "data")
}
value, err := state.ToValue(2)
if err == nil {
state.PushValue(value)
state.SetField(-2, key)
}
state.Pop(2)
return 0
}
func (s *Server) sessionFlash(state *luajit.State) int {
key, _ := state.SafeToString(1)
state.GetGlobal("__session")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.SetGlobal("__session")
state.GetGlobal("__session")
}
state.GetField(-1, "flash")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.PushCopy(-1)
state.SetField(-3, "flash")
}
value, err := state.ToValue(2)
if err == nil {
state.PushValue(value)
state.SetField(-2, key)
}
state.Pop(2)
return 0
}
func (s *Server) sessionGetFlash(state *luajit.State) int {
key, _ := state.SafeToString(1)
state.GetGlobal("__session")
if state.IsNil(-1) {
state.Pop(1)
state.PushNil()
return 1
}
state.GetField(-1, "flash")
if state.IsNil(-1) {
state.Pop(2)
state.PushNil()
return 1
}
state.GetField(-1, key)
return 1
}
func cookieSet(state *luajit.State) int {
name, _ := state.SafeToString(1)
value, _ := state.SafeToString(2)
maxAge := 0
path := "/"
domain := ""
secure := false
httpOnly := true
if state.GetTop() >= 3 && state.IsTable(3) {
maxAge = int(state.GetFieldNumber(3, "max_age", 0))
path = state.GetFieldString(3, "path", "/")
domain = state.GetFieldString(3, "domain", "")
secure = state.GetFieldBool(3, "secure", false)
httpOnly = state.GetFieldBool(3, "http_only", true)
}
state.GetGlobal("__response")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.SetGlobal("__response")
state.GetGlobal("__response")
}
state.GetField(-1, "cookies")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.PushCopy(-1)
state.SetField(-3, "cookies")
}
cookieData := map[string]any{
"name": name,
"value": value,
"path": path,
"secure": secure,
"http_only": httpOnly,
}
if domain != "" {
cookieData["domain"] = domain
}
if maxAge > 0 {
cookieData["max_age"] = maxAge
}
state.PushValue(cookieData)
length := globalServer.getTableLength(-2)
state.PushNumber(float64(length + 1))
state.PushCopy(-2)
state.SetTable(-4)
state.Pop(3)
return 0
}
func cookieGet(state *luajit.State) int {
name, _ := state.SafeToString(1)
state.GetGlobal("__request")
if state.IsNil(-1) {
state.Pop(1)
state.PushNil()
return 1
}
state.GetField(-1, "cookies")
if state.IsNil(-1) {
state.Pop(2)
state.PushNil()
return 1
}
state.GetField(-1, name)
return 1
}
func (s *Server) csrfGenerate(state *luajit.State) int {
token := s.generateCSRFToken()
state.GetGlobal("__session")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.SetGlobal("__session")
state.GetGlobal("__session")
}
state.GetField(-1, "data")
if state.IsNil(-1) {
state.Pop(1)
state.NewTable()
state.PushCopy(-1)
state.SetField(-3, "data")
}
state.PushString(token)
state.SetField(-2, "_csrf_token")
state.Pop(2)
state.PushString(token)
return 1
}
func (s *Server) csrfValidate(state *luajit.State) int {
state.GetGlobal("__session")
if state.IsNil(-1) {
state.Pop(1)
state.PushBoolean(false)
return 1
}
sessionToken := state.GetFieldString(-1, "data._csrf_token", "")
state.Pop(1)
state.GetGlobal("__request")
if state.IsNil(-1) {
state.Pop(1)
state.PushBoolean(false)
return 1
}
requestToken := state.GetFieldString(-1, "form._csrf_token", "")
state.Pop(1)
state.PushBoolean(sessionToken != "" && sessionToken == requestToken)
return 1
}
func (s *Server) getTableLength(index int) int {
length := 0
s.state.PushNil()
for s.state.Next(index - 1) {
length++
s.state.Pop(1)
}
return length
}

View File

@ -1,3 +1,7 @@
-- Fast response handling
local response = {status = 200, headers = {}, cookies = {}}
local session_data = {}
http = {} http = {}
function http.listen(port) function http.listen(port)
@ -9,16 +13,11 @@ function http.route(method, path, handler)
end end
function http.status(code) function http.status(code)
return __http_set_status(code) response.status = code
end end
function http.header(name, value) function http.header(k, v)
return __http_set_header(name, value) response.headers[k] = v
end
function http.redirect(url, status)
__http_redirect(url, status or 302)
coroutine.yield() -- Exit handler
end end
function http.json(data) function http.json(data)
@ -36,80 +35,115 @@ function http.text(content)
return content return content
end end
function http.redirect(url, code)
response.status = code or 302
response.headers["Location"] = url
coroutine.yield()
end
-- Session functions -- Session functions
session = {} session = {}
function session.get(key) function session.get(key)
return __session_get(key) return session_data[key]
end end
function session.set(key, value) function session.set(key, val)
return __session_set(key, value) session_data[key] = val
end end
function session.flash(key, value) function session.flash(key, val)
return __session_flash(key, value) session_data["_flash_" .. key] = val
end end
function session.get_flash(key) function session.get_flash(key)
return __session_get_flash(key) local val = session_data["_flash_" .. key]
session_data["_flash_" .. key] = nil
return val
end end
-- Cookie functions -- Cookie functions
cookie = {} cookie = {}
function cookie.set(name, value, options) function cookie.get(name)
return __cookie_set(name, value, options) return COOKIES and COOKIES[name]
end end
function cookie.get(name) function cookie.set(name, value, options)
return __cookie_get(name) response.cookies[#response.cookies + 1] = {
name = name,
value = value,
options = options or {}
}
end end
-- CSRF functions -- CSRF functions
csrf = {} csrf = {}
function csrf.generate() function csrf.generate()
return __csrf_generate() local token = CSRF_TOKEN or ""
session.set("_csrf_token", token)
return token
end end
function csrf.validate() function csrf.validate()
return __csrf_validate() local session_token = session.get("_csrf_token")
local form_token = FORM and FORM._csrf_token
return session_token and session_token == form_token
end end
function csrf.field() function csrf.field()
local token = csrf.generate() return '<input type="hidden" name="_csrf_token" value="' .. csrf.generate() .. '" />'
return string.format('<input type="hidden" name="_csrf_token" value="%s" />', token)
end end
-- Fast JSON encoding
json = {
encode = function(data)
if type(data) == "string" then
return '"' .. data .. '"'
elseif type(data) == "number" then
return tostring(data)
elseif type(data) == "boolean" then
return data and "true" or "false"
elseif data == nil then
return "null"
elseif type(data) == "table" then
-- Check if it's an array
local isArray = true
local n = 0
for k, v in pairs(data) do
n = n + 1
if type(k) ~= "number" or k ~= n then
isArray = false
break
end
end
if isArray then
local result = "["
for i = 1, n do
if i > 1 then result = result .. "," end
result = result .. json.encode(data[i])
end
return result .. "]"
else
local result = "{"
local first = true
for k, v in pairs(data) do
if not first then result = result .. "," end
result = result .. '"' .. tostring(k) .. '":' .. json.encode(v)
first = false
end
return result .. "}"
end
else
return json_encode_fallback(data)
end
end
}
-- Helper functions -- Helper functions
function redirect_with_flash(url, type, message) function redirect_with_flash(url, type, message)
session.flash(type, message) session.flash(type, message)
http.redirect(url) http.redirect(url)
end end
-- JSON encoding/decoding placeholder
json = {
encode = function(data)
-- Simplified JSON encoding
if type(data) == "table" then
local result = "{"
local first = true
for k, v in pairs(data) do
if not first then result = result .. "," end
result = result .. '"' .. tostring(k) .. '":' .. json.encode(v)
first = false
end
return result .. "}"
elseif type(data) == "string" then
return '"' .. data .. '"'
else
return tostring(data)
end
end,
decode = function(str)
-- Simplified JSON decoding - you'd want a proper implementation
return {}
end
}

View File

@ -1,25 +1,24 @@
package router package router
import ( import (
"errors" "fmt"
"strings"
"sync"
) )
// node represents a node in the radix trie // Handler function that takes parameters as strings
type Handler func(params []string)
type node struct { type node struct {
segment string segment string
handler int // Lua function reference handlerID int
children []*node children []*node
isDynamic bool // :param isDynamic bool
isWildcard bool // *param isWildcard bool
paramName string maxParams uint8
} }
// Router is a string-based HTTP router with efficient lookup
type Router struct { type Router struct {
get, post, put, patch, delete *node get, post, put, patch, delete *node
mu sync.RWMutex paramsBuffer []string
} }
// Params holds URL parameters // Params holds URL parameters
@ -41,15 +40,15 @@ func (p *Params) Get(name string) string {
// New creates a new Router instance // New creates a new Router instance
func New() *Router { func New() *Router {
return &Router{ return &Router{
get: &node{}, get: &node{},
post: &node{}, post: &node{},
put: &node{}, put: &node{},
patch: &node{}, patch: &node{},
delete: &node{}, delete: &node{},
paramsBuffer: make([]string, 64),
} }
} }
// methodNode returns the root node for a method
func (r *Router) methodNode(method string) *node { func (r *Router) methodNode(method string) *node {
switch method { switch method {
case "GET": case "GET":
@ -67,47 +66,71 @@ func (r *Router) methodNode(method string) *node {
} }
} }
// AddRoute adds a new route with handler reference // AddRoute adds a route with handler ID (for compatibility)
func (r *Router) AddRoute(method, path string, handlerRef int) error { func (r *Router) AddRoute(method, path string, handlerID int) error {
r.mu.Lock()
defer r.mu.Unlock()
root := r.methodNode(method) root := r.methodNode(method)
if root == nil { if root == nil {
return errors.New("unsupported HTTP method") return fmt.Errorf("unsupported method: %s", method)
} }
// Create a handler that stores the ID
h := func(params []string) {
// This is a placeholder - the actual execution happens in HTTP module
}
return r.addRoute(root, path, h, handlerID)
}
// readSegment extracts the next path segment
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
if start >= len(path) {
return "", start, false
}
if path[start] == '/' {
start++
}
if start >= len(path) {
return "", start, false
}
end = start
for end < len(path) && path[end] != '/' {
end++
}
return path[start:end], end, end < len(path)
}
// addRoute adds a new route to the trie
func (r *Router) addRoute(root *node, path string, h Handler, handlerID int) error {
if path == "/" { if path == "/" {
root.handler = handlerRef root.handlerID = handlerID
return nil return nil
} }
return r.addRoute(root, path, handlerRef)
}
// addRoute adds a route to the trie
func (r *Router) addRoute(root *node, path string, handlerRef int) error {
segments := r.parseSegments(path)
current := root current := root
pos := 0
lastWC := false
count := uint8(0)
for _, seg := range segments { for {
isDyn := strings.HasPrefix(seg, ":") seg, newPos, more := readSegment(path, pos)
isWC := strings.HasPrefix(seg, "*") if seg == "" {
break
if isWC && seg != segments[len(segments)-1] {
return errors.New("wildcard must be the last segment")
} }
paramName := "" isDyn := len(seg) > 1 && seg[0] == ':'
if isDyn { isWC := len(seg) > 0 && seg[0] == '*'
paramName = seg[1:]
seg = ":" if isWC {
} else if isWC { if lastWC || more {
paramName = seg[1:] return fmt.Errorf("wildcard must be the last segment in the path")
seg = "*" }
lastWC = true
}
if isDyn || isWC {
count++
} }
// Find or create child
var child *node var child *node
for _, c := range current.children { for _, c := range current.children {
if c.segment == seg { if c.segment == seg {
@ -117,144 +140,102 @@ func (r *Router) addRoute(root *node, path string, handlerRef int) error {
} }
if child == nil { if child == nil {
child = &node{ child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC}
segment: seg,
isDynamic: isDyn,
isWildcard: isWC,
paramName: paramName,
}
current.children = append(current.children, child) current.children = append(current.children, child)
} }
if child.maxParams < count {
child.maxParams = count
}
current = child current = child
pos = newPos
} }
current.handler = handlerRef current.handlerID = handlerID
return nil return nil
} }
// parseSegments splits path into segments // Lookup finds a handler matching method and path
func (r *Router) parseSegments(path string) []string {
segments := strings.Split(strings.Trim(path, "/"), "/")
var result []string
for _, seg := range segments {
if seg != "" {
result = append(result, seg)
}
}
return result
}
// Lookup finds handler and parameters for a method and path
func (r *Router) Lookup(method, path string) (int, *Params, bool) { func (r *Router) Lookup(method, path string) (int, *Params, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
root := r.methodNode(method) root := r.methodNode(method)
if root == nil { if root == nil {
return 0, nil, false return 0, nil, false
} }
if path == "/" { if path == "/" {
if root.handler != 0 { if root.handlerID != 0 {
return root.handler, &Params{}, true return root.handlerID, &Params{}, true
} }
return 0, nil, false return 0, nil, false
} }
segments := r.parseSegments(path) buffer := r.paramsBuffer
handler, params := r.match(root, segments, 0) if cap(buffer) < int(root.maxParams) {
if handler == 0 { buffer = make([]string, root.maxParams)
r.paramsBuffer = buffer
}
buffer = buffer[:0]
handlerID, paramCount, paramKeys, found := r.match(root, path, 0, &buffer)
if !found {
return 0, nil, false return 0, nil, false
} }
return handler, params, true params := &Params{
} Keys: paramKeys,
Values: buffer[:paramCount],
// match traverses the trie to find handler
func (r *Router) match(current *node, segments []string, index int) (int, *Params) {
if index >= len(segments) {
if current.handler != 0 {
return current.handler, &Params{}
}
return 0, nil
} }
segment := segments[index] return handlerID, params, true
}
// Check exact match first // match traverses the trie to find a handler
for _, child := range current.children { func (r *Router) match(current *node, path string, start int, params *[]string) (int, int, []string, bool) {
if child.segment == segment { paramCount := 0
handler, params := r.match(child, segments, index+1) var paramKeys []string
if handler != 0 {
return handler, params // Check wildcards first
for _, c := range current.children {
if c.isWildcard {
rem := path[start:]
if len(rem) > 0 && rem[0] == '/' {
rem = rem[1:]
}
*params = append(*params, rem)
// Extract param name from *name format
paramName := c.segment[1:] // Remove the * prefix
paramKeys = append(paramKeys, paramName)
return c.handlerID, 1, paramKeys, c.handlerID != 0
}
}
seg, pos, more := readSegment(path, start)
if seg == "" {
return current.handlerID, 0, paramKeys, current.handlerID != 0
}
for _, c := range current.children {
if c.segment == seg || c.isDynamic {
if c.isDynamic {
*params = append(*params, seg)
// Extract param name from :name format
paramName := c.segment[1:] // Remove the : prefix
paramKeys = append(paramKeys, paramName)
paramCount++
}
if !more {
return c.handlerID, paramCount, paramKeys, c.handlerID != 0
}
handlerID, nestedCount, nestedKeys, ok := r.match(c, path, pos, params)
if ok {
allKeys := append(paramKeys, nestedKeys...)
return handlerID, paramCount + nestedCount, allKeys, true
} }
} }
} }
// Check dynamic match second return 0, 0, nil, false
for _, child := range current.children {
if child.isDynamic {
handler, params := r.match(child, segments, index+1)
if handler != 0 {
// Prepend this parameter
newParams := &Params{
Keys: append([]string{child.paramName}, params.Keys...),
Values: append([]string{segment}, params.Values...),
}
return handler, newParams
}
}
}
// Check wildcard last (catches everything remaining)
for _, child := range current.children {
if child.isWildcard {
remaining := strings.Join(segments[index:], "/")
return child.handler, &Params{
Keys: []string{child.paramName},
Values: []string{remaining},
}
}
}
return 0, nil
}
// RemoveRoute removes a route
func (r *Router) RemoveRoute(method, path string) {
r.mu.Lock()
defer r.mu.Unlock()
root := r.methodNode(method)
if root == nil {
return
}
if path == "/" {
root.handler = 0
return
}
segments := r.parseSegments(path)
r.removeRoute(root, segments, 0)
}
// removeRoute removes a route from the trie
func (r *Router) removeRoute(current *node, segments []string, index int) {
if index >= len(segments) {
current.handler = 0
return
}
segment := segments[index]
for _, child := range current.children {
if child.segment == segment ||
(child.isDynamic && strings.HasPrefix(segment, ":")) ||
(child.isWildcard && strings.HasPrefix(segment, "*")) {
r.removeRoute(child, segments, index+1)
break
}
}
} }