next pass
This commit is contained in:
parent
bb06e2431d
commit
843e318e01
843
http/http.go
843
http/http.go
@ -21,25 +21,33 @@ import (
|
|||||||
//go:embed http.lua
|
//go:embed http.lua
|
||||||
var httpLuaCode string
|
var httpLuaCode string
|
||||||
|
|
||||||
|
// HandlerFunc represents a Lua handler
|
||||||
|
type HandlerFunc struct {
|
||||||
|
bytecode []byte
|
||||||
|
funcRef int
|
||||||
|
name string
|
||||||
|
isFunction bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server with single state for function handler compatibility
|
||||||
type Server struct {
|
type Server struct {
|
||||||
server *fasthttp.Server
|
server *fasthttp.Server
|
||||||
router *router.Router
|
router *router.Router
|
||||||
sessions *sessions.SessionManager
|
sessions *sessions.SessionManager
|
||||||
state *luajit.State
|
state *luajit.State
|
||||||
stateMu sync.Mutex
|
stateMu sync.Mutex
|
||||||
|
handlers map[int]*HandlerFunc
|
||||||
|
handlersMu sync.RWMutex
|
||||||
funcCounter int
|
funcCounter int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequestContext with lazy parsing
|
||||||
type RequestContext struct {
|
type RequestContext struct {
|
||||||
Method string
|
ctx *fasthttp.RequestCtx
|
||||||
Path string
|
params *router.Params
|
||||||
Headers map[string]string
|
session *sessions.Session
|
||||||
Query map[string]string
|
parsedForm map[string]any
|
||||||
Form map[string]any
|
formOnce sync.Once
|
||||||
Cookies map[string]string
|
|
||||||
Session *sessions.Session
|
|
||||||
Body string
|
|
||||||
Params map[string]string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var globalServer *Server
|
var globalServer *Server
|
||||||
@ -49,6 +57,7 @@ func NewServer(state *luajit.State) *Server {
|
|||||||
router: router.New(),
|
router: router.New(),
|
||||||
sessions: sessions.NewSessionManager(10000),
|
sessions: sessions.NewSessionManager(10000),
|
||||||
state: state,
|
state: state,
|
||||||
|
handlers: make(map[int]*HandlerFunc),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,19 +65,8 @@ func RegisterHTTPFunctions(L *luajit.State) error {
|
|||||||
globalServer = NewServer(L)
|
globalServer = NewServer(L)
|
||||||
|
|
||||||
functions := map[string]luajit.GoFunction{
|
functions := map[string]luajit.GoFunction{
|
||||||
"__http_listen": globalServer.httpListen,
|
"__http_listen": globalServer.httpListen,
|
||||||
"__http_route": globalServer.httpRoute,
|
"__http_route": globalServer.httpRoute,
|
||||||
"__http_set_status": httpSetStatus,
|
|
||||||
"__http_set_header": httpSetHeader,
|
|
||||||
"__http_redirect": httpRedirect,
|
|
||||||
"__session_get": globalServer.sessionGet,
|
|
||||||
"__session_set": globalServer.sessionSet,
|
|
||||||
"__session_flash": globalServer.sessionFlash,
|
|
||||||
"__session_get_flash": globalServer.sessionGetFlash,
|
|
||||||
"__cookie_set": cookieSet,
|
|
||||||
"__cookie_get": cookieGet,
|
|
||||||
"__csrf_generate": globalServer.csrfGenerate,
|
|
||||||
"__csrf_validate": globalServer.csrfValidate,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, fn := range functions {
|
for name, fn := range functions {
|
||||||
@ -87,8 +85,13 @@ func (s *Server) httpListen(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.server = &fasthttp.Server{
|
s.server = &fasthttp.Server{
|
||||||
Handler: s.requestHandler,
|
Handler: s.fastRequestHandler,
|
||||||
Name: "Moonshark/1.0",
|
Name: "Moonshark/2.0",
|
||||||
|
Concurrency: 256 * 1024,
|
||||||
|
ReadBufferSize: 4096,
|
||||||
|
WriteBufferSize: 4096,
|
||||||
|
ReduceMemoryUsage: true,
|
||||||
|
NoDefaultServerHeader: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := fmt.Sprintf(":%d", int(port))
|
addr := fmt.Sprintf(":%d", int(port))
|
||||||
@ -114,16 +117,44 @@ func (s *Server) httpRoute(state *luajit.State) int {
|
|||||||
return state.PushError("route: path must be string")
|
return state.PushError("route: path must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !state.IsFunction(3) {
|
s.funcCounter++
|
||||||
return state.PushError("route: handler must be function")
|
handlerID := s.funcCounter
|
||||||
|
|
||||||
|
if state.IsFunction(3) {
|
||||||
|
// Function handler - store reference
|
||||||
|
state.PushCopy(3)
|
||||||
|
funcRef := s.storeFunction(state)
|
||||||
|
|
||||||
|
s.handlersMu.Lock()
|
||||||
|
s.handlers[handlerID] = &HandlerFunc{
|
||||||
|
funcRef: funcRef,
|
||||||
|
name: fmt.Sprintf("%s %s", method, path),
|
||||||
|
isFunction: true,
|
||||||
|
}
|
||||||
|
s.handlersMu.Unlock()
|
||||||
|
} else {
|
||||||
|
// String handler - compile to bytecode
|
||||||
|
handlerCode, err := state.SafeToString(3)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("route: handler must be function or string")
|
||||||
|
}
|
||||||
|
|
||||||
|
bytecode, err := state.CompileBytecode(handlerCode, fmt.Sprintf("handler_%s_%s", method, path))
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("route: failed to compile handler: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
s.handlersMu.Lock()
|
||||||
|
s.handlers[handlerID] = &HandlerFunc{
|
||||||
|
bytecode: bytecode,
|
||||||
|
name: fmt.Sprintf("%s %s", method, path),
|
||||||
|
isFunction: false,
|
||||||
|
}
|
||||||
|
s.handlersMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store function and get reference
|
|
||||||
state.PushCopy(3)
|
|
||||||
funcRef := s.storeFunction()
|
|
||||||
|
|
||||||
// Add route to router
|
// Add route to router
|
||||||
if err := s.router.AddRoute(strings.ToUpper(method), path, funcRef); err != nil {
|
if err := s.router.AddRoute(strings.ToUpper(method), path, handlerID); err != nil {
|
||||||
return state.PushError("route: failed to add route: %s", err.Error())
|
return state.PushError("route: failed to add route: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,133 +162,282 @@ func (s *Server) httpRoute(state *luajit.State) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) storeFunction() int {
|
func (s *Server) storeFunction(state *luajit.State) int {
|
||||||
s.state.GetGlobal("__moonshark_functions")
|
state.GetGlobal("__moonshark_functions")
|
||||||
if s.state.IsNil(-1) {
|
if state.IsNil(-1) {
|
||||||
s.state.Pop(1)
|
state.Pop(1)
|
||||||
s.state.NewTable()
|
state.NewTable()
|
||||||
s.state.PushCopy(-1)
|
state.PushCopy(-1)
|
||||||
s.state.SetGlobal("__moonshark_functions")
|
state.SetGlobal("__moonshark_functions")
|
||||||
}
|
}
|
||||||
|
|
||||||
s.funcCounter++
|
s.funcCounter++
|
||||||
s.state.PushNumber(float64(s.funcCounter))
|
state.PushNumber(float64(s.funcCounter))
|
||||||
s.state.PushCopy(-3)
|
state.PushCopy(-3)
|
||||||
s.state.SetTable(-3)
|
state.SetTable(-3)
|
||||||
s.state.Pop(2)
|
state.Pop(2)
|
||||||
|
|
||||||
return s.funcCounter
|
return s.funcCounter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getFunction(ref int) bool {
|
func (s *Server) getFunction(state *luajit.State, ref int) bool {
|
||||||
s.state.GetGlobal("__moonshark_functions")
|
state.GetGlobal("__moonshark_functions")
|
||||||
if s.state.IsNil(-1) {
|
if state.IsNil(-1) {
|
||||||
s.state.Pop(1)
|
state.Pop(1)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
s.state.PushNumber(float64(ref))
|
state.PushNumber(float64(ref))
|
||||||
s.state.GetTable(-2)
|
state.GetTable(-2)
|
||||||
isFunc := s.state.IsFunction(-1)
|
isFunc := state.IsFunction(-1)
|
||||||
if !isFunc {
|
if !isFunc {
|
||||||
s.state.Pop(2)
|
state.Pop(2)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
s.state.Remove(-2)
|
state.Remove(-2)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) requestHandler(ctx *fasthttp.RequestCtx) {
|
func (s *Server) fastRequestHandler(ctx *fasthttp.RequestCtx) {
|
||||||
method := string(ctx.Method())
|
method := string(ctx.Method())
|
||||||
path := string(ctx.Path())
|
path := string(ctx.Path())
|
||||||
|
|
||||||
// Look up route in router
|
// Fast route lookup
|
||||||
handlerRef, params, found := s.router.Lookup(method, path)
|
handlerID, params, found := s.router.Lookup(method, path)
|
||||||
if !found {
|
if !found {
|
||||||
ctx.SetStatusCode(404)
|
ctx.SetStatusCode(404)
|
||||||
ctx.SetBodyString("Not Found")
|
ctx.SetBodyString("Not Found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx := s.buildRequestContext(ctx, params)
|
// Get compiled handler
|
||||||
reqCtx.Session.AdvanceFlash()
|
s.handlersMu.RLock()
|
||||||
|
handler := s.handlers[handlerID]
|
||||||
|
s.handlersMu.RUnlock()
|
||||||
|
|
||||||
s.stateMu.Lock()
|
if handler == nil {
|
||||||
defer s.stateMu.Unlock()
|
|
||||||
|
|
||||||
s.setupRequestEnvironment(reqCtx)
|
|
||||||
|
|
||||||
if !s.getFunction(handlerRef) {
|
|
||||||
ctx.SetStatusCode(500)
|
ctx.SetStatusCode(500)
|
||||||
ctx.SetBodyString("Handler not found")
|
ctx.SetBodyString("Handler not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil {
|
// Lock state for execution
|
||||||
ctx.SetStatusCode(500)
|
s.stateMu.Lock()
|
||||||
ctx.SetBodyString("Failed to create request object")
|
defer s.stateMu.Unlock()
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.state.Call(1, 1); err != nil {
|
// Setup request context
|
||||||
ctx.SetStatusCode(500)
|
reqCtx := &RequestContext{
|
||||||
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
|
ctx: ctx,
|
||||||
return
|
params: params,
|
||||||
|
session: s.sessions.GetSessionFromRequest(ctx),
|
||||||
}
|
}
|
||||||
|
reqCtx.session.AdvanceFlash()
|
||||||
|
|
||||||
var responseBody string
|
var responseBody string
|
||||||
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
|
|
||||||
responseBody = s.state.ToString(-1)
|
if handler.isFunction {
|
||||||
|
// Function handler - use traditional approach
|
||||||
|
s.setupFunctionEnvironment(s.state, reqCtx)
|
||||||
|
|
||||||
|
if !s.getFunction(s.state, handler.funcRef) {
|
||||||
|
ctx.SetStatusCode(500)
|
||||||
|
ctx.SetBodyString("Function handler not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push request object
|
||||||
|
if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil {
|
||||||
|
ctx.SetStatusCode(500)
|
||||||
|
ctx.SetBodyString("Failed to create request object")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.state.Call(1, 1); err != nil {
|
||||||
|
ctx.SetStatusCode(500)
|
||||||
|
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
|
||||||
|
responseBody = s.state.ToString(-1)
|
||||||
|
}
|
||||||
|
s.state.Pop(1)
|
||||||
|
} else {
|
||||||
|
// Bytecode handler - use fast approach
|
||||||
|
s.setupFastEnvironment(s.state, reqCtx)
|
||||||
|
|
||||||
|
if err := s.state.LoadAndRunBytecode(handler.bytecode, handler.name); err != nil {
|
||||||
|
ctx.SetStatusCode(500)
|
||||||
|
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
|
||||||
|
responseBody = s.state.ToString(-1)
|
||||||
|
}
|
||||||
s.state.Pop(1)
|
s.state.Pop(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.updateSessionFromLua(reqCtx.Session)
|
// Apply response
|
||||||
s.applyResponse(ctx, responseBody)
|
s.applyResponse(ctx, s.state, responseBody, reqCtx.session)
|
||||||
s.sessions.ApplySessionCookie(ctx, reqCtx.Session)
|
s.sessions.ApplySessionCookie(ctx, reqCtx.session)
|
||||||
|
|
||||||
|
// Clean up state
|
||||||
|
s.state.SetTop(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) setupRequestEnvironment(reqCtx *RequestContext) {
|
func (s *Server) setupFunctionEnvironment(state *luajit.State, reqCtx *RequestContext) {
|
||||||
s.state.PushValue(s.requestToTable(reqCtx))
|
// Set up response globals for function handlers
|
||||||
s.state.SetGlobal("__request")
|
state.NewTable()
|
||||||
|
state.PushNumber(200)
|
||||||
|
state.SetField(-2, "status")
|
||||||
|
state.NewTable()
|
||||||
|
state.SetField(-2, "headers")
|
||||||
|
state.NewTable()
|
||||||
|
state.SetField(-2, "cookies")
|
||||||
|
state.SetGlobal("__response")
|
||||||
|
|
||||||
s.state.PushValue(s.sessionToTable(reqCtx.Session))
|
// Session data
|
||||||
s.state.SetGlobal("__session")
|
if !reqCtx.session.IsEmpty() {
|
||||||
|
state.PushValue(reqCtx.session.GetAll())
|
||||||
|
state.SetGlobal("__session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s.state.NewTable()
|
func (s *Server) setupFastEnvironment(state *luajit.State, reqCtx *RequestContext) {
|
||||||
s.state.SetGlobal("__response")
|
// Request basics as globals for fast access
|
||||||
|
state.PushString(string(reqCtx.ctx.Method()))
|
||||||
|
state.SetGlobal("REQUEST_METHOD")
|
||||||
|
|
||||||
|
state.PushString(string(reqCtx.ctx.Path()))
|
||||||
|
state.SetGlobal("REQUEST_PATH")
|
||||||
|
|
||||||
|
// Parameters
|
||||||
|
if reqCtx.params != nil && len(reqCtx.params.Keys) > 0 {
|
||||||
|
paramMap := make(map[string]string, len(reqCtx.params.Keys))
|
||||||
|
for i, key := range reqCtx.params.Keys {
|
||||||
|
if i < len(reqCtx.params.Values) {
|
||||||
|
paramMap[key] = reqCtx.params.Values[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.PushValue(paramMap)
|
||||||
|
state.SetGlobal("PARAMS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query parameters
|
||||||
|
queryMap := make(map[string]string)
|
||||||
|
reqCtx.ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||||
|
queryMap[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
if len(queryMap) > 0 {
|
||||||
|
state.PushValue(queryMap)
|
||||||
|
state.SetGlobal("QUERY")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Headers
|
||||||
|
headerMap := make(map[string]string)
|
||||||
|
reqCtx.ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||||
|
headerMap[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
state.PushValue(headerMap)
|
||||||
|
state.SetGlobal("HEADERS")
|
||||||
|
|
||||||
|
// Cookies
|
||||||
|
cookieMap := make(map[string]string)
|
||||||
|
reqCtx.ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||||
|
cookieMap[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
if len(cookieMap) > 0 {
|
||||||
|
state.PushValue(cookieMap)
|
||||||
|
state.SetGlobal("COOKIES")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Form data
|
||||||
|
if reqCtx.ctx.IsPost() || reqCtx.ctx.IsPut() || reqCtx.ctx.IsPatch() {
|
||||||
|
form := s.parseForm(reqCtx.ctx)
|
||||||
|
if len(form) > 0 {
|
||||||
|
state.PushValue(form)
|
||||||
|
state.SetGlobal("FORM")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session data
|
||||||
|
if !reqCtx.session.IsEmpty() {
|
||||||
|
state.PushValue(reqCtx.session.GetAll())
|
||||||
|
state.SetGlobal("session_data")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CSRF token
|
||||||
|
if csrfToken := s.generateCSRFToken(); csrfToken != "" {
|
||||||
|
state.PushString(csrfToken)
|
||||||
|
state.SetGlobal("CSRF_TOKEN")
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSON encode fallback
|
||||||
|
state.RegisterGoFunction("json_encode_fallback", func(state *luajit.State) int {
|
||||||
|
val, _ := state.ToValue(1)
|
||||||
|
if b, err := json.Marshal(val); err == nil {
|
||||||
|
state.PushString(string(b))
|
||||||
|
} else {
|
||||||
|
state.PushString("null")
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any {
|
func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any {
|
||||||
return map[string]any{
|
req := map[string]any{
|
||||||
"method": reqCtx.Method,
|
"method": string(reqCtx.ctx.Method()),
|
||||||
"path": reqCtx.Path,
|
"path": string(reqCtx.ctx.Path()),
|
||||||
"headers": reqCtx.Headers,
|
"headers": make(map[string]string),
|
||||||
"query": reqCtx.Query,
|
"query": make(map[string]string),
|
||||||
"form": reqCtx.Form,
|
"cookies": make(map[string]string),
|
||||||
"cookies": reqCtx.Cookies,
|
"body": string(reqCtx.ctx.PostBody()),
|
||||||
"body": reqCtx.Body,
|
|
||||||
"params": reqCtx.Params,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Headers
|
||||||
|
headers := req["headers"].(map[string]string)
|
||||||
|
reqCtx.ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||||
|
headers[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Cookies
|
||||||
|
cookies := req["cookies"].(map[string]string)
|
||||||
|
reqCtx.ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||||
|
cookies[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Query
|
||||||
|
query := req["query"].(map[string]string)
|
||||||
|
reqCtx.ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||||
|
query[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Params
|
||||||
|
if reqCtx.params != nil && len(reqCtx.params.Keys) > 0 {
|
||||||
|
params := make(map[string]string, len(reqCtx.params.Keys))
|
||||||
|
for i, key := range reqCtx.params.Keys {
|
||||||
|
if i < len(reqCtx.params.Values) {
|
||||||
|
params[key] = reqCtx.params.Values[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req["params"] = params
|
||||||
|
}
|
||||||
|
|
||||||
|
// Form
|
||||||
|
if reqCtx.ctx.IsPost() || reqCtx.ctx.IsPut() || reqCtx.ctx.IsPatch() {
|
||||||
|
req["form"] = s.parseForm(reqCtx.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sessionToTable(session *sessions.Session) map[string]any {
|
func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, state *luajit.State, body string, session *sessions.Session) {
|
||||||
return map[string]any{
|
// Update session from Lua
|
||||||
"id": session.ID,
|
state.GetGlobal("session_data")
|
||||||
"data": session.GetAll(),
|
if state.IsTable(-1) {
|
||||||
}
|
if data, err := state.ToTable(-1); err == nil {
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) updateSessionFromLua(session *sessions.Session) {
|
|
||||||
s.state.GetGlobal("__session")
|
|
||||||
if s.state.IsNil(-1) {
|
|
||||||
s.state.Pop(1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.state.GetField(-1, "data")
|
|
||||||
if s.state.IsTable(-1) {
|
|
||||||
if data, err := s.state.ToTable(-1); err == nil {
|
|
||||||
if dataMap, ok := data.(map[string]any); ok {
|
if dataMap, ok := data.(map[string]any); ok {
|
||||||
session.Clear()
|
session.Clear()
|
||||||
for k, v := range dataMap {
|
for k, v := range dataMap {
|
||||||
@ -266,55 +446,59 @@ func (s *Server) updateSessionFromLua(session *sessions.Session) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.state.Pop(2)
|
state.Pop(1)
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, body string) {
|
// Check for response table (function handlers) or response global (fast handlers)
|
||||||
s.state.GetGlobal("__response")
|
state.GetGlobal("__response")
|
||||||
if s.state.IsNil(-1) {
|
if state.IsNil(-1) {
|
||||||
s.state.Pop(1)
|
state.Pop(1)
|
||||||
if body != "" {
|
state.GetGlobal("response")
|
||||||
ctx.SetBodyString(body)
|
if state.IsNil(-1) {
|
||||||
|
state.Pop(1)
|
||||||
|
if body != "" {
|
||||||
|
ctx.SetBodyString(body)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.state.GetField(-1, "status")
|
// Status
|
||||||
if s.state.IsNumber(-1) {
|
if status := state.GetFieldNumber(-1, "status", 200); status != 200 {
|
||||||
ctx.SetStatusCode(int(s.state.ToNumber(-1)))
|
ctx.SetStatusCode(int(status))
|
||||||
}
|
}
|
||||||
s.state.Pop(1)
|
|
||||||
|
|
||||||
s.state.GetField(-1, "headers")
|
// Headers
|
||||||
if s.state.IsTable(-1) {
|
state.GetField(-1, "headers")
|
||||||
s.state.ForEachTableKV(-1, func(key, value string) bool {
|
if state.IsTable(-1) {
|
||||||
|
state.ForEachTableKV(-1, func(key, value string) bool {
|
||||||
ctx.Response.Header.Set(key, value)
|
ctx.Response.Header.Set(key, value)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
s.state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
s.state.GetField(-1, "cookies")
|
// Cookies
|
||||||
if s.state.IsTable(-1) {
|
state.GetField(-1, "cookies")
|
||||||
s.applyCookies(ctx)
|
if state.IsTable(-1) {
|
||||||
|
s.applyCookies(ctx, state)
|
||||||
}
|
}
|
||||||
s.state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
s.state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
if body != "" {
|
if body != "" {
|
||||||
ctx.SetBodyString(body)
|
ctx.SetBodyString(body)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
|
func (s *Server) applyCookies(ctx *fasthttp.RequestCtx, state *luajit.State) {
|
||||||
s.state.ForEachArray(-1, func(i int, state *luajit.State) bool {
|
state.ForEachArray(-1, func(i int, st *luajit.State) bool {
|
||||||
if !state.IsTable(-1) {
|
if !st.IsTable(-1) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
name := state.GetFieldString(-1, "name", "")
|
name := st.GetFieldString(-1, "name", "")
|
||||||
value := state.GetFieldString(-1, "value", "")
|
value := st.GetFieldString(-1, "value", "")
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -324,22 +508,29 @@ func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
|
|||||||
|
|
||||||
cookie.SetKey(name)
|
cookie.SetKey(name)
|
||||||
cookie.SetValue(value)
|
cookie.SetValue(value)
|
||||||
cookie.SetPath(state.GetFieldString(-1, "path", "/"))
|
|
||||||
|
|
||||||
if domain := state.GetFieldString(-1, "domain", ""); domain != "" {
|
if options, ok := st.GetFieldTable(-1, "options"); ok {
|
||||||
cookie.SetDomain(domain)
|
if optMap, ok := options.(map[string]any); ok {
|
||||||
}
|
if path, ok := optMap["path"].(string); ok {
|
||||||
|
cookie.SetPath(path)
|
||||||
if state.GetFieldBool(-1, "secure", false) {
|
} else {
|
||||||
cookie.SetSecure(true)
|
cookie.SetPath("/")
|
||||||
}
|
}
|
||||||
|
if domain, ok := optMap["domain"].(string); ok {
|
||||||
if state.GetFieldBool(-1, "http_only", true) {
|
cookie.SetDomain(domain)
|
||||||
cookie.SetHTTPOnly(true)
|
}
|
||||||
}
|
if secure, ok := optMap["secure"].(bool); ok && secure {
|
||||||
|
cookie.SetSecure(true)
|
||||||
if maxAge := state.GetFieldNumber(-1, "max_age", 0); maxAge > 0 {
|
}
|
||||||
cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second))
|
if httpOnly, ok := optMap["http_only"].(bool); ok {
|
||||||
|
cookie.SetHTTPOnly(httpOnly)
|
||||||
|
} else {
|
||||||
|
cookie.SetHTTPOnly(true)
|
||||||
|
}
|
||||||
|
if maxAge, ok := optMap["max_age"].(int); ok && maxAge > 0 {
|
||||||
|
cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Response.Header.SetCookie(cookie)
|
ctx.Response.Header.SetCookie(cookie)
|
||||||
@ -347,42 +538,6 @@ func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) buildRequestContext(ctx *fasthttp.RequestCtx, params *router.Params) *RequestContext {
|
|
||||||
reqCtx := &RequestContext{
|
|
||||||
Method: string(ctx.Method()),
|
|
||||||
Path: string(ctx.Path()),
|
|
||||||
Headers: make(map[string]string),
|
|
||||||
Query: make(map[string]string),
|
|
||||||
Cookies: make(map[string]string),
|
|
||||||
Body: string(ctx.PostBody()),
|
|
||||||
Params: make(map[string]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
|
||||||
reqCtx.Headers[string(key)] = string(value)
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
|
||||||
reqCtx.Cookies[string(key)] = string(value)
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
|
||||||
reqCtx.Query[string(key)] = string(value)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Convert router params to map
|
|
||||||
for i, key := range params.Keys {
|
|
||||||
if i < len(params.Values) {
|
|
||||||
reqCtx.Params[key] = params.Values[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
reqCtx.Form = s.parseForm(ctx)
|
|
||||||
reqCtx.Session = s.sessions.GetSessionFromRequest(ctx)
|
|
||||||
|
|
||||||
return reqCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any {
|
func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any {
|
||||||
contentType := string(ctx.Request.Header.ContentType())
|
contentType := string(ctx.Request.Header.ContentType())
|
||||||
form := make(map[string]any)
|
form := make(map[string]any)
|
||||||
@ -440,317 +595,3 @@ func (s *Server) generateCSRFToken() string {
|
|||||||
rand.Read(bytes)
|
rand.Read(bytes)
|
||||||
return base64.URLEncoding.EncodeToString(bytes)
|
return base64.URLEncoding.EncodeToString(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lua function implementations
|
|
||||||
func httpSetStatus(state *luajit.State) int {
|
|
||||||
code, _ := state.SafeToNumber(1)
|
|
||||||
state.GetGlobal("__response")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal("__response")
|
|
||||||
state.GetGlobal("__response")
|
|
||||||
}
|
|
||||||
state.PushNumber(code)
|
|
||||||
state.SetField(-2, "status")
|
|
||||||
state.Pop(1)
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func httpSetHeader(state *luajit.State) int {
|
|
||||||
name, _ := state.SafeToString(1)
|
|
||||||
value, _ := state.SafeToString(2)
|
|
||||||
|
|
||||||
state.GetGlobal("__response")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal("__response")
|
|
||||||
state.GetGlobal("__response")
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "headers")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.PushCopy(-1)
|
|
||||||
state.SetField(-3, "headers")
|
|
||||||
}
|
|
||||||
|
|
||||||
state.PushString(value)
|
|
||||||
state.SetField(-2, name)
|
|
||||||
state.Pop(2)
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func httpRedirect(state *luajit.State) int {
|
|
||||||
url, _ := state.SafeToString(1)
|
|
||||||
status := 302.0
|
|
||||||
if state.GetTop() >= 2 {
|
|
||||||
status, _ = state.SafeToNumber(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetGlobal("__response")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal("__response")
|
|
||||||
state.GetGlobal("__response")
|
|
||||||
}
|
|
||||||
|
|
||||||
state.PushNumber(status)
|
|
||||||
state.SetField(-2, "status")
|
|
||||||
|
|
||||||
state.GetField(-1, "headers")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.PushCopy(-1)
|
|
||||||
state.SetField(-3, "headers")
|
|
||||||
}
|
|
||||||
|
|
||||||
state.PushString(url)
|
|
||||||
state.SetField(-2, "Location")
|
|
||||||
state.Pop(2)
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) sessionGet(state *luajit.State) int {
|
|
||||||
key, _ := state.SafeToString(1)
|
|
||||||
|
|
||||||
state.GetGlobal("__session")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.PushNil()
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "data")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(2)
|
|
||||||
state.PushNil()
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, key)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) sessionSet(state *luajit.State) int {
|
|
||||||
key, _ := state.SafeToString(1)
|
|
||||||
|
|
||||||
state.GetGlobal("__session")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal("__session")
|
|
||||||
state.GetGlobal("__session")
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "data")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.PushCopy(-1)
|
|
||||||
state.SetField(-3, "data")
|
|
||||||
}
|
|
||||||
|
|
||||||
value, err := state.ToValue(2)
|
|
||||||
if err == nil {
|
|
||||||
state.PushValue(value)
|
|
||||||
state.SetField(-2, key)
|
|
||||||
}
|
|
||||||
state.Pop(2)
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) sessionFlash(state *luajit.State) int {
|
|
||||||
key, _ := state.SafeToString(1)
|
|
||||||
|
|
||||||
state.GetGlobal("__session")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal("__session")
|
|
||||||
state.GetGlobal("__session")
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "flash")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.PushCopy(-1)
|
|
||||||
state.SetField(-3, "flash")
|
|
||||||
}
|
|
||||||
|
|
||||||
value, err := state.ToValue(2)
|
|
||||||
if err == nil {
|
|
||||||
state.PushValue(value)
|
|
||||||
state.SetField(-2, key)
|
|
||||||
}
|
|
||||||
state.Pop(2)
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) sessionGetFlash(state *luajit.State) int {
|
|
||||||
key, _ := state.SafeToString(1)
|
|
||||||
|
|
||||||
state.GetGlobal("__session")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.PushNil()
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "flash")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(2)
|
|
||||||
state.PushNil()
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, key)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func cookieSet(state *luajit.State) int {
|
|
||||||
name, _ := state.SafeToString(1)
|
|
||||||
value, _ := state.SafeToString(2)
|
|
||||||
|
|
||||||
maxAge := 0
|
|
||||||
path := "/"
|
|
||||||
domain := ""
|
|
||||||
secure := false
|
|
||||||
httpOnly := true
|
|
||||||
|
|
||||||
if state.GetTop() >= 3 && state.IsTable(3) {
|
|
||||||
maxAge = int(state.GetFieldNumber(3, "max_age", 0))
|
|
||||||
path = state.GetFieldString(3, "path", "/")
|
|
||||||
domain = state.GetFieldString(3, "domain", "")
|
|
||||||
secure = state.GetFieldBool(3, "secure", false)
|
|
||||||
httpOnly = state.GetFieldBool(3, "http_only", true)
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetGlobal("__response")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal("__response")
|
|
||||||
state.GetGlobal("__response")
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "cookies")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.PushCopy(-1)
|
|
||||||
state.SetField(-3, "cookies")
|
|
||||||
}
|
|
||||||
|
|
||||||
cookieData := map[string]any{
|
|
||||||
"name": name,
|
|
||||||
"value": value,
|
|
||||||
"path": path,
|
|
||||||
"secure": secure,
|
|
||||||
"http_only": httpOnly,
|
|
||||||
}
|
|
||||||
if domain != "" {
|
|
||||||
cookieData["domain"] = domain
|
|
||||||
}
|
|
||||||
if maxAge > 0 {
|
|
||||||
cookieData["max_age"] = maxAge
|
|
||||||
}
|
|
||||||
|
|
||||||
state.PushValue(cookieData)
|
|
||||||
|
|
||||||
length := globalServer.getTableLength(-2)
|
|
||||||
state.PushNumber(float64(length + 1))
|
|
||||||
state.PushCopy(-2)
|
|
||||||
state.SetTable(-4)
|
|
||||||
|
|
||||||
state.Pop(3)
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func cookieGet(state *luajit.State) int {
|
|
||||||
name, _ := state.SafeToString(1)
|
|
||||||
|
|
||||||
state.GetGlobal("__request")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.PushNil()
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "cookies")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(2)
|
|
||||||
state.PushNil()
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, name)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) csrfGenerate(state *luajit.State) int {
|
|
||||||
token := s.generateCSRFToken()
|
|
||||||
|
|
||||||
state.GetGlobal("__session")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.SetGlobal("__session")
|
|
||||||
state.GetGlobal("__session")
|
|
||||||
}
|
|
||||||
|
|
||||||
state.GetField(-1, "data")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.NewTable()
|
|
||||||
state.PushCopy(-1)
|
|
||||||
state.SetField(-3, "data")
|
|
||||||
}
|
|
||||||
|
|
||||||
state.PushString(token)
|
|
||||||
state.SetField(-2, "_csrf_token")
|
|
||||||
state.Pop(2)
|
|
||||||
|
|
||||||
state.PushString(token)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) csrfValidate(state *luajit.State) int {
|
|
||||||
state.GetGlobal("__session")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionToken := state.GetFieldString(-1, "data._csrf_token", "")
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
state.GetGlobal("__request")
|
|
||||||
if state.IsNil(-1) {
|
|
||||||
state.Pop(1)
|
|
||||||
state.PushBoolean(false)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
requestToken := state.GetFieldString(-1, "form._csrf_token", "")
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
state.PushBoolean(sessionToken != "" && sessionToken == requestToken)
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getTableLength(index int) int {
|
|
||||||
length := 0
|
|
||||||
s.state.PushNil()
|
|
||||||
for s.state.Next(index - 1) {
|
|
||||||
length++
|
|
||||||
s.state.Pop(1)
|
|
||||||
}
|
|
||||||
return length
|
|
||||||
}
|
|
||||||
|
130
http/http.lua
130
http/http.lua
@ -1,3 +1,7 @@
|
|||||||
|
-- Fast response handling
|
||||||
|
local response = {status = 200, headers = {}, cookies = {}}
|
||||||
|
local session_data = {}
|
||||||
|
|
||||||
http = {}
|
http = {}
|
||||||
|
|
||||||
function http.listen(port)
|
function http.listen(port)
|
||||||
@ -9,16 +13,11 @@ function http.route(method, path, handler)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function http.status(code)
|
function http.status(code)
|
||||||
return __http_set_status(code)
|
response.status = code
|
||||||
end
|
end
|
||||||
|
|
||||||
function http.header(name, value)
|
function http.header(k, v)
|
||||||
return __http_set_header(name, value)
|
response.headers[k] = v
|
||||||
end
|
|
||||||
|
|
||||||
function http.redirect(url, status)
|
|
||||||
__http_redirect(url, status or 302)
|
|
||||||
coroutine.yield() -- Exit handler
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function http.json(data)
|
function http.json(data)
|
||||||
@ -36,80 +35,115 @@ function http.text(content)
|
|||||||
return content
|
return content
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function http.redirect(url, code)
|
||||||
|
response.status = code or 302
|
||||||
|
response.headers["Location"] = url
|
||||||
|
coroutine.yield()
|
||||||
|
end
|
||||||
|
|
||||||
-- Session functions
|
-- Session functions
|
||||||
session = {}
|
session = {}
|
||||||
|
|
||||||
function session.get(key)
|
function session.get(key)
|
||||||
return __session_get(key)
|
return session_data[key]
|
||||||
end
|
end
|
||||||
|
|
||||||
function session.set(key, value)
|
function session.set(key, val)
|
||||||
return __session_set(key, value)
|
session_data[key] = val
|
||||||
end
|
end
|
||||||
|
|
||||||
function session.flash(key, value)
|
function session.flash(key, val)
|
||||||
return __session_flash(key, value)
|
session_data["_flash_" .. key] = val
|
||||||
end
|
end
|
||||||
|
|
||||||
function session.get_flash(key)
|
function session.get_flash(key)
|
||||||
return __session_get_flash(key)
|
local val = session_data["_flash_" .. key]
|
||||||
|
session_data["_flash_" .. key] = nil
|
||||||
|
return val
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Cookie functions
|
-- Cookie functions
|
||||||
cookie = {}
|
cookie = {}
|
||||||
|
|
||||||
function cookie.set(name, value, options)
|
function cookie.get(name)
|
||||||
return __cookie_set(name, value, options)
|
return COOKIES and COOKIES[name]
|
||||||
end
|
end
|
||||||
|
|
||||||
function cookie.get(name)
|
function cookie.set(name, value, options)
|
||||||
return __cookie_get(name)
|
response.cookies[#response.cookies + 1] = {
|
||||||
|
name = name,
|
||||||
|
value = value,
|
||||||
|
options = options or {}
|
||||||
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
-- CSRF functions
|
-- CSRF functions
|
||||||
csrf = {}
|
csrf = {}
|
||||||
|
|
||||||
function csrf.generate()
|
function csrf.generate()
|
||||||
return __csrf_generate()
|
local token = CSRF_TOKEN or ""
|
||||||
|
session.set("_csrf_token", token)
|
||||||
|
return token
|
||||||
end
|
end
|
||||||
|
|
||||||
function csrf.validate()
|
function csrf.validate()
|
||||||
return __csrf_validate()
|
local session_token = session.get("_csrf_token")
|
||||||
|
local form_token = FORM and FORM._csrf_token
|
||||||
|
return session_token and session_token == form_token
|
||||||
end
|
end
|
||||||
|
|
||||||
function csrf.field()
|
function csrf.field()
|
||||||
local token = csrf.generate()
|
return '<input type="hidden" name="_csrf_token" value="' .. csrf.generate() .. '" />'
|
||||||
return string.format('<input type="hidden" name="_csrf_token" value="%s" />', token)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Fast JSON encoding
|
||||||
|
json = {
|
||||||
|
encode = function(data)
|
||||||
|
if type(data) == "string" then
|
||||||
|
return '"' .. data .. '"'
|
||||||
|
elseif type(data) == "number" then
|
||||||
|
return tostring(data)
|
||||||
|
elseif type(data) == "boolean" then
|
||||||
|
return data and "true" or "false"
|
||||||
|
elseif data == nil then
|
||||||
|
return "null"
|
||||||
|
elseif type(data) == "table" then
|
||||||
|
-- Check if it's an array
|
||||||
|
local isArray = true
|
||||||
|
local n = 0
|
||||||
|
for k, v in pairs(data) do
|
||||||
|
n = n + 1
|
||||||
|
if type(k) ~= "number" or k ~= n then
|
||||||
|
isArray = false
|
||||||
|
break
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if isArray then
|
||||||
|
local result = "["
|
||||||
|
for i = 1, n do
|
||||||
|
if i > 1 then result = result .. "," end
|
||||||
|
result = result .. json.encode(data[i])
|
||||||
|
end
|
||||||
|
return result .. "]"
|
||||||
|
else
|
||||||
|
local result = "{"
|
||||||
|
local first = true
|
||||||
|
for k, v in pairs(data) do
|
||||||
|
if not first then result = result .. "," end
|
||||||
|
result = result .. '"' .. tostring(k) .. '":' .. json.encode(v)
|
||||||
|
first = false
|
||||||
|
end
|
||||||
|
return result .. "}"
|
||||||
|
end
|
||||||
|
else
|
||||||
|
return json_encode_fallback(data)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
}
|
||||||
|
|
||||||
-- Helper functions
|
-- Helper functions
|
||||||
function redirect_with_flash(url, type, message)
|
function redirect_with_flash(url, type, message)
|
||||||
session.flash(type, message)
|
session.flash(type, message)
|
||||||
http.redirect(url)
|
http.redirect(url)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- JSON encoding/decoding placeholder
|
|
||||||
json = {
|
|
||||||
encode = function(data)
|
|
||||||
-- Simplified JSON encoding
|
|
||||||
if type(data) == "table" then
|
|
||||||
local result = "{"
|
|
||||||
local first = true
|
|
||||||
for k, v in pairs(data) do
|
|
||||||
if not first then result = result .. "," end
|
|
||||||
result = result .. '"' .. tostring(k) .. '":' .. json.encode(v)
|
|
||||||
first = false
|
|
||||||
end
|
|
||||||
return result .. "}"
|
|
||||||
elseif type(data) == "string" then
|
|
||||||
return '"' .. data .. '"'
|
|
||||||
else
|
|
||||||
return tostring(data)
|
|
||||||
end
|
|
||||||
end,
|
|
||||||
|
|
||||||
decode = function(str)
|
|
||||||
-- Simplified JSON decoding - you'd want a proper implementation
|
|
||||||
return {}
|
|
||||||
end
|
|
||||||
}
|
|
@ -1,25 +1,24 @@
|
|||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// node represents a node in the radix trie
|
// Handler function that takes parameters as strings
|
||||||
|
type Handler func(params []string)
|
||||||
|
|
||||||
type node struct {
|
type node struct {
|
||||||
segment string
|
segment string
|
||||||
handler int // Lua function reference
|
handlerID int
|
||||||
children []*node
|
children []*node
|
||||||
isDynamic bool // :param
|
isDynamic bool
|
||||||
isWildcard bool // *param
|
isWildcard bool
|
||||||
paramName string
|
maxParams uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
// Router is a string-based HTTP router with efficient lookup
|
|
||||||
type Router struct {
|
type Router struct {
|
||||||
get, post, put, patch, delete *node
|
get, post, put, patch, delete *node
|
||||||
mu sync.RWMutex
|
paramsBuffer []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Params holds URL parameters
|
// Params holds URL parameters
|
||||||
@ -41,15 +40,15 @@ func (p *Params) Get(name string) string {
|
|||||||
// New creates a new Router instance
|
// New creates a new Router instance
|
||||||
func New() *Router {
|
func New() *Router {
|
||||||
return &Router{
|
return &Router{
|
||||||
get: &node{},
|
get: &node{},
|
||||||
post: &node{},
|
post: &node{},
|
||||||
put: &node{},
|
put: &node{},
|
||||||
patch: &node{},
|
patch: &node{},
|
||||||
delete: &node{},
|
delete: &node{},
|
||||||
|
paramsBuffer: make([]string, 64),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// methodNode returns the root node for a method
|
|
||||||
func (r *Router) methodNode(method string) *node {
|
func (r *Router) methodNode(method string) *node {
|
||||||
switch method {
|
switch method {
|
||||||
case "GET":
|
case "GET":
|
||||||
@ -67,47 +66,71 @@ func (r *Router) methodNode(method string) *node {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRoute adds a new route with handler reference
|
// AddRoute adds a route with handler ID (for compatibility)
|
||||||
func (r *Router) AddRoute(method, path string, handlerRef int) error {
|
func (r *Router) AddRoute(method, path string, handlerID int) error {
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
root := r.methodNode(method)
|
root := r.methodNode(method)
|
||||||
if root == nil {
|
if root == nil {
|
||||||
return errors.New("unsupported HTTP method")
|
return fmt.Errorf("unsupported method: %s", method)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a handler that stores the ID
|
||||||
|
h := func(params []string) {
|
||||||
|
// This is a placeholder - the actual execution happens in HTTP module
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.addRoute(root, path, h, handlerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// readSegment extracts the next path segment
|
||||||
|
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
|
||||||
|
if start >= len(path) {
|
||||||
|
return "", start, false
|
||||||
|
}
|
||||||
|
if path[start] == '/' {
|
||||||
|
start++
|
||||||
|
}
|
||||||
|
if start >= len(path) {
|
||||||
|
return "", start, false
|
||||||
|
}
|
||||||
|
end = start
|
||||||
|
for end < len(path) && path[end] != '/' {
|
||||||
|
end++
|
||||||
|
}
|
||||||
|
return path[start:end], end, end < len(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRoute adds a new route to the trie
|
||||||
|
func (r *Router) addRoute(root *node, path string, h Handler, handlerID int) error {
|
||||||
if path == "/" {
|
if path == "/" {
|
||||||
root.handler = handlerRef
|
root.handlerID = handlerID
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.addRoute(root, path, handlerRef)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRoute adds a route to the trie
|
|
||||||
func (r *Router) addRoute(root *node, path string, handlerRef int) error {
|
|
||||||
segments := r.parseSegments(path)
|
|
||||||
current := root
|
current := root
|
||||||
|
pos := 0
|
||||||
|
lastWC := false
|
||||||
|
count := uint8(0)
|
||||||
|
|
||||||
for _, seg := range segments {
|
for {
|
||||||
isDyn := strings.HasPrefix(seg, ":")
|
seg, newPos, more := readSegment(path, pos)
|
||||||
isWC := strings.HasPrefix(seg, "*")
|
if seg == "" {
|
||||||
|
break
|
||||||
if isWC && seg != segments[len(segments)-1] {
|
|
||||||
return errors.New("wildcard must be the last segment")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
paramName := ""
|
isDyn := len(seg) > 1 && seg[0] == ':'
|
||||||
if isDyn {
|
isWC := len(seg) > 0 && seg[0] == '*'
|
||||||
paramName = seg[1:]
|
|
||||||
seg = ":"
|
if isWC {
|
||||||
} else if isWC {
|
if lastWC || more {
|
||||||
paramName = seg[1:]
|
return fmt.Errorf("wildcard must be the last segment in the path")
|
||||||
seg = "*"
|
}
|
||||||
|
lastWC = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if isDyn || isWC {
|
||||||
|
count++
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find or create child
|
|
||||||
var child *node
|
var child *node
|
||||||
for _, c := range current.children {
|
for _, c := range current.children {
|
||||||
if c.segment == seg {
|
if c.segment == seg {
|
||||||
@ -117,144 +140,102 @@ func (r *Router) addRoute(root *node, path string, handlerRef int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if child == nil {
|
if child == nil {
|
||||||
child = &node{
|
child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC}
|
||||||
segment: seg,
|
|
||||||
isDynamic: isDyn,
|
|
||||||
isWildcard: isWC,
|
|
||||||
paramName: paramName,
|
|
||||||
}
|
|
||||||
current.children = append(current.children, child)
|
current.children = append(current.children, child)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if child.maxParams < count {
|
||||||
|
child.maxParams = count
|
||||||
|
}
|
||||||
|
|
||||||
current = child
|
current = child
|
||||||
|
pos = newPos
|
||||||
}
|
}
|
||||||
|
|
||||||
current.handler = handlerRef
|
current.handlerID = handlerID
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseSegments splits path into segments
|
// Lookup finds a handler matching method and path
|
||||||
func (r *Router) parseSegments(path string) []string {
|
|
||||||
segments := strings.Split(strings.Trim(path, "/"), "/")
|
|
||||||
var result []string
|
|
||||||
for _, seg := range segments {
|
|
||||||
if seg != "" {
|
|
||||||
result = append(result, seg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lookup finds handler and parameters for a method and path
|
|
||||||
func (r *Router) Lookup(method, path string) (int, *Params, bool) {
|
func (r *Router) Lookup(method, path string) (int, *Params, bool) {
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
root := r.methodNode(method)
|
root := r.methodNode(method)
|
||||||
if root == nil {
|
if root == nil {
|
||||||
return 0, nil, false
|
return 0, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if path == "/" {
|
if path == "/" {
|
||||||
if root.handler != 0 {
|
if root.handlerID != 0 {
|
||||||
return root.handler, &Params{}, true
|
return root.handlerID, &Params{}, true
|
||||||
}
|
}
|
||||||
return 0, nil, false
|
return 0, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
segments := r.parseSegments(path)
|
buffer := r.paramsBuffer
|
||||||
handler, params := r.match(root, segments, 0)
|
if cap(buffer) < int(root.maxParams) {
|
||||||
if handler == 0 {
|
buffer = make([]string, root.maxParams)
|
||||||
|
r.paramsBuffer = buffer
|
||||||
|
}
|
||||||
|
buffer = buffer[:0]
|
||||||
|
|
||||||
|
handlerID, paramCount, paramKeys, found := r.match(root, path, 0, &buffer)
|
||||||
|
if !found {
|
||||||
return 0, nil, false
|
return 0, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return handler, params, true
|
params := &Params{
|
||||||
}
|
Keys: paramKeys,
|
||||||
|
Values: buffer[:paramCount],
|
||||||
// match traverses the trie to find handler
|
|
||||||
func (r *Router) match(current *node, segments []string, index int) (int, *Params) {
|
|
||||||
if index >= len(segments) {
|
|
||||||
if current.handler != 0 {
|
|
||||||
return current.handler, &Params{}
|
|
||||||
}
|
|
||||||
return 0, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
segment := segments[index]
|
return handlerID, params, true
|
||||||
|
}
|
||||||
|
|
||||||
// Check exact match first
|
// match traverses the trie to find a handler
|
||||||
for _, child := range current.children {
|
func (r *Router) match(current *node, path string, start int, params *[]string) (int, int, []string, bool) {
|
||||||
if child.segment == segment {
|
paramCount := 0
|
||||||
handler, params := r.match(child, segments, index+1)
|
var paramKeys []string
|
||||||
if handler != 0 {
|
|
||||||
return handler, params
|
// Check wildcards first
|
||||||
|
for _, c := range current.children {
|
||||||
|
if c.isWildcard {
|
||||||
|
rem := path[start:]
|
||||||
|
if len(rem) > 0 && rem[0] == '/' {
|
||||||
|
rem = rem[1:]
|
||||||
|
}
|
||||||
|
*params = append(*params, rem)
|
||||||
|
// Extract param name from *name format
|
||||||
|
paramName := c.segment[1:] // Remove the * prefix
|
||||||
|
paramKeys = append(paramKeys, paramName)
|
||||||
|
return c.handlerID, 1, paramKeys, c.handlerID != 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
seg, pos, more := readSegment(path, start)
|
||||||
|
if seg == "" {
|
||||||
|
return current.handlerID, 0, paramKeys, current.handlerID != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range current.children {
|
||||||
|
if c.segment == seg || c.isDynamic {
|
||||||
|
if c.isDynamic {
|
||||||
|
*params = append(*params, seg)
|
||||||
|
// Extract param name from :name format
|
||||||
|
paramName := c.segment[1:] // Remove the : prefix
|
||||||
|
paramKeys = append(paramKeys, paramName)
|
||||||
|
paramCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if !more {
|
||||||
|
return c.handlerID, paramCount, paramKeys, c.handlerID != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerID, nestedCount, nestedKeys, ok := r.match(c, path, pos, params)
|
||||||
|
if ok {
|
||||||
|
allKeys := append(paramKeys, nestedKeys...)
|
||||||
|
return handlerID, paramCount + nestedCount, allKeys, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check dynamic match second
|
return 0, 0, nil, false
|
||||||
for _, child := range current.children {
|
|
||||||
if child.isDynamic {
|
|
||||||
handler, params := r.match(child, segments, index+1)
|
|
||||||
if handler != 0 {
|
|
||||||
// Prepend this parameter
|
|
||||||
newParams := &Params{
|
|
||||||
Keys: append([]string{child.paramName}, params.Keys...),
|
|
||||||
Values: append([]string{segment}, params.Values...),
|
|
||||||
}
|
|
||||||
return handler, newParams
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check wildcard last (catches everything remaining)
|
|
||||||
for _, child := range current.children {
|
|
||||||
if child.isWildcard {
|
|
||||||
remaining := strings.Join(segments[index:], "/")
|
|
||||||
return child.handler, &Params{
|
|
||||||
Keys: []string{child.paramName},
|
|
||||||
Values: []string{remaining},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveRoute removes a route
|
|
||||||
func (r *Router) RemoveRoute(method, path string) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
root := r.methodNode(method)
|
|
||||||
if root == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if path == "/" {
|
|
||||||
root.handler = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
segments := r.parseSegments(path)
|
|
||||||
r.removeRoute(root, segments, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeRoute removes a route from the trie
|
|
||||||
func (r *Router) removeRoute(current *node, segments []string, index int) {
|
|
||||||
if index >= len(segments) {
|
|
||||||
current.handler = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
segment := segments[index]
|
|
||||||
|
|
||||||
for _, child := range current.children {
|
|
||||||
if child.segment == segment ||
|
|
||||||
(child.isDynamic && strings.HasPrefix(segment, ":")) ||
|
|
||||||
(child.isWildcard && strings.HasPrefix(segment, "*")) {
|
|
||||||
r.removeRoute(child, segments, index+1)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user