next pass
This commit is contained in:
parent
bb06e2431d
commit
843e318e01
843
http/http.go
843
http/http.go
@ -21,25 +21,33 @@ import (
|
||||
//go:embed http.lua
|
||||
var httpLuaCode string
|
||||
|
||||
// HandlerFunc represents a Lua handler
|
||||
type HandlerFunc struct {
|
||||
bytecode []byte
|
||||
funcRef int
|
||||
name string
|
||||
isFunction bool
|
||||
}
|
||||
|
||||
// Server with single state for function handler compatibility
|
||||
type Server struct {
|
||||
server *fasthttp.Server
|
||||
router *router.Router
|
||||
sessions *sessions.SessionManager
|
||||
state *luajit.State
|
||||
stateMu sync.Mutex
|
||||
handlers map[int]*HandlerFunc
|
||||
handlersMu sync.RWMutex
|
||||
funcCounter int
|
||||
}
|
||||
|
||||
// RequestContext with lazy parsing
|
||||
type RequestContext struct {
|
||||
Method string
|
||||
Path string
|
||||
Headers map[string]string
|
||||
Query map[string]string
|
||||
Form map[string]any
|
||||
Cookies map[string]string
|
||||
Session *sessions.Session
|
||||
Body string
|
||||
Params map[string]string
|
||||
ctx *fasthttp.RequestCtx
|
||||
params *router.Params
|
||||
session *sessions.Session
|
||||
parsedForm map[string]any
|
||||
formOnce sync.Once
|
||||
}
|
||||
|
||||
var globalServer *Server
|
||||
@ -49,6 +57,7 @@ func NewServer(state *luajit.State) *Server {
|
||||
router: router.New(),
|
||||
sessions: sessions.NewSessionManager(10000),
|
||||
state: state,
|
||||
handlers: make(map[int]*HandlerFunc),
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,19 +65,8 @@ func RegisterHTTPFunctions(L *luajit.State) error {
|
||||
globalServer = NewServer(L)
|
||||
|
||||
functions := map[string]luajit.GoFunction{
|
||||
"__http_listen": globalServer.httpListen,
|
||||
"__http_route": globalServer.httpRoute,
|
||||
"__http_set_status": httpSetStatus,
|
||||
"__http_set_header": httpSetHeader,
|
||||
"__http_redirect": httpRedirect,
|
||||
"__session_get": globalServer.sessionGet,
|
||||
"__session_set": globalServer.sessionSet,
|
||||
"__session_flash": globalServer.sessionFlash,
|
||||
"__session_get_flash": globalServer.sessionGetFlash,
|
||||
"__cookie_set": cookieSet,
|
||||
"__cookie_get": cookieGet,
|
||||
"__csrf_generate": globalServer.csrfGenerate,
|
||||
"__csrf_validate": globalServer.csrfValidate,
|
||||
"__http_listen": globalServer.httpListen,
|
||||
"__http_route": globalServer.httpRoute,
|
||||
}
|
||||
|
||||
for name, fn := range functions {
|
||||
@ -87,8 +85,13 @@ func (s *Server) httpListen(state *luajit.State) int {
|
||||
}
|
||||
|
||||
s.server = &fasthttp.Server{
|
||||
Handler: s.requestHandler,
|
||||
Name: "Moonshark/1.0",
|
||||
Handler: s.fastRequestHandler,
|
||||
Name: "Moonshark/2.0",
|
||||
Concurrency: 256 * 1024,
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
ReduceMemoryUsage: true,
|
||||
NoDefaultServerHeader: true,
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%d", int(port))
|
||||
@ -114,16 +117,44 @@ func (s *Server) httpRoute(state *luajit.State) int {
|
||||
return state.PushError("route: path must be string")
|
||||
}
|
||||
|
||||
if !state.IsFunction(3) {
|
||||
return state.PushError("route: handler must be function")
|
||||
s.funcCounter++
|
||||
handlerID := s.funcCounter
|
||||
|
||||
if state.IsFunction(3) {
|
||||
// Function handler - store reference
|
||||
state.PushCopy(3)
|
||||
funcRef := s.storeFunction(state)
|
||||
|
||||
s.handlersMu.Lock()
|
||||
s.handlers[handlerID] = &HandlerFunc{
|
||||
funcRef: funcRef,
|
||||
name: fmt.Sprintf("%s %s", method, path),
|
||||
isFunction: true,
|
||||
}
|
||||
s.handlersMu.Unlock()
|
||||
} else {
|
||||
// String handler - compile to bytecode
|
||||
handlerCode, err := state.SafeToString(3)
|
||||
if err != nil {
|
||||
return state.PushError("route: handler must be function or string")
|
||||
}
|
||||
|
||||
bytecode, err := state.CompileBytecode(handlerCode, fmt.Sprintf("handler_%s_%s", method, path))
|
||||
if err != nil {
|
||||
return state.PushError("route: failed to compile handler: %s", err.Error())
|
||||
}
|
||||
|
||||
s.handlersMu.Lock()
|
||||
s.handlers[handlerID] = &HandlerFunc{
|
||||
bytecode: bytecode,
|
||||
name: fmt.Sprintf("%s %s", method, path),
|
||||
isFunction: false,
|
||||
}
|
||||
s.handlersMu.Unlock()
|
||||
}
|
||||
|
||||
// Store function and get reference
|
||||
state.PushCopy(3)
|
||||
funcRef := s.storeFunction()
|
||||
|
||||
// Add route to router
|
||||
if err := s.router.AddRoute(strings.ToUpper(method), path, funcRef); err != nil {
|
||||
if err := s.router.AddRoute(strings.ToUpper(method), path, handlerID); err != nil {
|
||||
return state.PushError("route: failed to add route: %s", err.Error())
|
||||
}
|
||||
|
||||
@ -131,133 +162,282 @@ func (s *Server) httpRoute(state *luajit.State) int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) storeFunction() int {
|
||||
s.state.GetGlobal("__moonshark_functions")
|
||||
if s.state.IsNil(-1) {
|
||||
s.state.Pop(1)
|
||||
s.state.NewTable()
|
||||
s.state.PushCopy(-1)
|
||||
s.state.SetGlobal("__moonshark_functions")
|
||||
func (s *Server) storeFunction(state *luajit.State) int {
|
||||
state.GetGlobal("__moonshark_functions")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetGlobal("__moonshark_functions")
|
||||
}
|
||||
|
||||
s.funcCounter++
|
||||
s.state.PushNumber(float64(s.funcCounter))
|
||||
s.state.PushCopy(-3)
|
||||
s.state.SetTable(-3)
|
||||
s.state.Pop(2)
|
||||
state.PushNumber(float64(s.funcCounter))
|
||||
state.PushCopy(-3)
|
||||
state.SetTable(-3)
|
||||
state.Pop(2)
|
||||
|
||||
return s.funcCounter
|
||||
}
|
||||
|
||||
func (s *Server) getFunction(ref int) bool {
|
||||
s.state.GetGlobal("__moonshark_functions")
|
||||
if s.state.IsNil(-1) {
|
||||
s.state.Pop(1)
|
||||
func (s *Server) getFunction(state *luajit.State, ref int) bool {
|
||||
state.GetGlobal("__moonshark_functions")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
return false
|
||||
}
|
||||
|
||||
s.state.PushNumber(float64(ref))
|
||||
s.state.GetTable(-2)
|
||||
isFunc := s.state.IsFunction(-1)
|
||||
state.PushNumber(float64(ref))
|
||||
state.GetTable(-2)
|
||||
isFunc := state.IsFunction(-1)
|
||||
if !isFunc {
|
||||
s.state.Pop(2)
|
||||
state.Pop(2)
|
||||
return false
|
||||
}
|
||||
|
||||
s.state.Remove(-2)
|
||||
state.Remove(-2)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
func (s *Server) fastRequestHandler(ctx *fasthttp.RequestCtx) {
|
||||
method := string(ctx.Method())
|
||||
path := string(ctx.Path())
|
||||
|
||||
// Look up route in router
|
||||
handlerRef, params, found := s.router.Lookup(method, path)
|
||||
// Fast route lookup
|
||||
handlerID, params, found := s.router.Lookup(method, path)
|
||||
if !found {
|
||||
ctx.SetStatusCode(404)
|
||||
ctx.SetBodyString("Not Found")
|
||||
return
|
||||
}
|
||||
|
||||
reqCtx := s.buildRequestContext(ctx, params)
|
||||
reqCtx.Session.AdvanceFlash()
|
||||
// Get compiled handler
|
||||
s.handlersMu.RLock()
|
||||
handler := s.handlers[handlerID]
|
||||
s.handlersMu.RUnlock()
|
||||
|
||||
s.stateMu.Lock()
|
||||
defer s.stateMu.Unlock()
|
||||
|
||||
s.setupRequestEnvironment(reqCtx)
|
||||
|
||||
if !s.getFunction(handlerRef) {
|
||||
if handler == nil {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString("Handler not found")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString("Failed to create request object")
|
||||
return
|
||||
}
|
||||
// Lock state for execution
|
||||
s.stateMu.Lock()
|
||||
defer s.stateMu.Unlock()
|
||||
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
|
||||
return
|
||||
// Setup request context
|
||||
reqCtx := &RequestContext{
|
||||
ctx: ctx,
|
||||
params: params,
|
||||
session: s.sessions.GetSessionFromRequest(ctx),
|
||||
}
|
||||
reqCtx.session.AdvanceFlash()
|
||||
|
||||
var responseBody string
|
||||
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
|
||||
responseBody = s.state.ToString(-1)
|
||||
|
||||
if handler.isFunction {
|
||||
// Function handler - use traditional approach
|
||||
s.setupFunctionEnvironment(s.state, reqCtx)
|
||||
|
||||
if !s.getFunction(s.state, handler.funcRef) {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString("Function handler not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Push request object
|
||||
if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString("Failed to create request object")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
|
||||
responseBody = s.state.ToString(-1)
|
||||
}
|
||||
s.state.Pop(1)
|
||||
} else {
|
||||
// Bytecode handler - use fast approach
|
||||
s.setupFastEnvironment(s.state, reqCtx)
|
||||
|
||||
if err := s.state.LoadAndRunBytecode(handler.bytecode, handler.name); err != nil {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
|
||||
responseBody = s.state.ToString(-1)
|
||||
}
|
||||
s.state.Pop(1)
|
||||
}
|
||||
|
||||
s.updateSessionFromLua(reqCtx.Session)
|
||||
s.applyResponse(ctx, responseBody)
|
||||
s.sessions.ApplySessionCookie(ctx, reqCtx.Session)
|
||||
// Apply response
|
||||
s.applyResponse(ctx, s.state, responseBody, reqCtx.session)
|
||||
s.sessions.ApplySessionCookie(ctx, reqCtx.session)
|
||||
|
||||
// Clean up state
|
||||
s.state.SetTop(0)
|
||||
}
|
||||
|
||||
func (s *Server) setupRequestEnvironment(reqCtx *RequestContext) {
|
||||
s.state.PushValue(s.requestToTable(reqCtx))
|
||||
s.state.SetGlobal("__request")
|
||||
func (s *Server) setupFunctionEnvironment(state *luajit.State, reqCtx *RequestContext) {
|
||||
// Set up response globals for function handlers
|
||||
state.NewTable()
|
||||
state.PushNumber(200)
|
||||
state.SetField(-2, "status")
|
||||
state.NewTable()
|
||||
state.SetField(-2, "headers")
|
||||
state.NewTable()
|
||||
state.SetField(-2, "cookies")
|
||||
state.SetGlobal("__response")
|
||||
|
||||
s.state.PushValue(s.sessionToTable(reqCtx.Session))
|
||||
s.state.SetGlobal("__session")
|
||||
// Session data
|
||||
if !reqCtx.session.IsEmpty() {
|
||||
state.PushValue(reqCtx.session.GetAll())
|
||||
state.SetGlobal("__session")
|
||||
}
|
||||
}
|
||||
|
||||
s.state.NewTable()
|
||||
s.state.SetGlobal("__response")
|
||||
func (s *Server) setupFastEnvironment(state *luajit.State, reqCtx *RequestContext) {
|
||||
// Request basics as globals for fast access
|
||||
state.PushString(string(reqCtx.ctx.Method()))
|
||||
state.SetGlobal("REQUEST_METHOD")
|
||||
|
||||
state.PushString(string(reqCtx.ctx.Path()))
|
||||
state.SetGlobal("REQUEST_PATH")
|
||||
|
||||
// Parameters
|
||||
if reqCtx.params != nil && len(reqCtx.params.Keys) > 0 {
|
||||
paramMap := make(map[string]string, len(reqCtx.params.Keys))
|
||||
for i, key := range reqCtx.params.Keys {
|
||||
if i < len(reqCtx.params.Values) {
|
||||
paramMap[key] = reqCtx.params.Values[i]
|
||||
}
|
||||
}
|
||||
state.PushValue(paramMap)
|
||||
state.SetGlobal("PARAMS")
|
||||
}
|
||||
|
||||
// Query parameters
|
||||
queryMap := make(map[string]string)
|
||||
reqCtx.ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
queryMap[string(key)] = string(value)
|
||||
})
|
||||
if len(queryMap) > 0 {
|
||||
state.PushValue(queryMap)
|
||||
state.SetGlobal("QUERY")
|
||||
}
|
||||
|
||||
// Headers
|
||||
headerMap := make(map[string]string)
|
||||
reqCtx.ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
headerMap[string(key)] = string(value)
|
||||
})
|
||||
state.PushValue(headerMap)
|
||||
state.SetGlobal("HEADERS")
|
||||
|
||||
// Cookies
|
||||
cookieMap := make(map[string]string)
|
||||
reqCtx.ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||
cookieMap[string(key)] = string(value)
|
||||
})
|
||||
if len(cookieMap) > 0 {
|
||||
state.PushValue(cookieMap)
|
||||
state.SetGlobal("COOKIES")
|
||||
}
|
||||
|
||||
// Form data
|
||||
if reqCtx.ctx.IsPost() || reqCtx.ctx.IsPut() || reqCtx.ctx.IsPatch() {
|
||||
form := s.parseForm(reqCtx.ctx)
|
||||
if len(form) > 0 {
|
||||
state.PushValue(form)
|
||||
state.SetGlobal("FORM")
|
||||
}
|
||||
}
|
||||
|
||||
// Session data
|
||||
if !reqCtx.session.IsEmpty() {
|
||||
state.PushValue(reqCtx.session.GetAll())
|
||||
state.SetGlobal("session_data")
|
||||
}
|
||||
|
||||
// CSRF token
|
||||
if csrfToken := s.generateCSRFToken(); csrfToken != "" {
|
||||
state.PushString(csrfToken)
|
||||
state.SetGlobal("CSRF_TOKEN")
|
||||
}
|
||||
|
||||
// JSON encode fallback
|
||||
state.RegisterGoFunction("json_encode_fallback", func(state *luajit.State) int {
|
||||
val, _ := state.ToValue(1)
|
||||
if b, err := json.Marshal(val); err == nil {
|
||||
state.PushString(string(b))
|
||||
} else {
|
||||
state.PushString("null")
|
||||
}
|
||||
return 1
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any {
|
||||
return map[string]any{
|
||||
"method": reqCtx.Method,
|
||||
"path": reqCtx.Path,
|
||||
"headers": reqCtx.Headers,
|
||||
"query": reqCtx.Query,
|
||||
"form": reqCtx.Form,
|
||||
"cookies": reqCtx.Cookies,
|
||||
"body": reqCtx.Body,
|
||||
"params": reqCtx.Params,
|
||||
req := map[string]any{
|
||||
"method": string(reqCtx.ctx.Method()),
|
||||
"path": string(reqCtx.ctx.Path()),
|
||||
"headers": make(map[string]string),
|
||||
"query": make(map[string]string),
|
||||
"cookies": make(map[string]string),
|
||||
"body": string(reqCtx.ctx.PostBody()),
|
||||
}
|
||||
|
||||
// Headers
|
||||
headers := req["headers"].(map[string]string)
|
||||
reqCtx.ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
headers[string(key)] = string(value)
|
||||
})
|
||||
|
||||
// Cookies
|
||||
cookies := req["cookies"].(map[string]string)
|
||||
reqCtx.ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||
cookies[string(key)] = string(value)
|
||||
})
|
||||
|
||||
// Query
|
||||
query := req["query"].(map[string]string)
|
||||
reqCtx.ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
query[string(key)] = string(value)
|
||||
})
|
||||
|
||||
// Params
|
||||
if reqCtx.params != nil && len(reqCtx.params.Keys) > 0 {
|
||||
params := make(map[string]string, len(reqCtx.params.Keys))
|
||||
for i, key := range reqCtx.params.Keys {
|
||||
if i < len(reqCtx.params.Values) {
|
||||
params[key] = reqCtx.params.Values[i]
|
||||
}
|
||||
}
|
||||
req["params"] = params
|
||||
}
|
||||
|
||||
// Form
|
||||
if reqCtx.ctx.IsPost() || reqCtx.ctx.IsPut() || reqCtx.ctx.IsPatch() {
|
||||
req["form"] = s.parseForm(reqCtx.ctx)
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func (s *Server) sessionToTable(session *sessions.Session) map[string]any {
|
||||
return map[string]any{
|
||||
"id": session.ID,
|
||||
"data": session.GetAll(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) updateSessionFromLua(session *sessions.Session) {
|
||||
s.state.GetGlobal("__session")
|
||||
if s.state.IsNil(-1) {
|
||||
s.state.Pop(1)
|
||||
return
|
||||
}
|
||||
|
||||
s.state.GetField(-1, "data")
|
||||
if s.state.IsTable(-1) {
|
||||
if data, err := s.state.ToTable(-1); err == nil {
|
||||
func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, state *luajit.State, body string, session *sessions.Session) {
|
||||
// Update session from Lua
|
||||
state.GetGlobal("session_data")
|
||||
if state.IsTable(-1) {
|
||||
if data, err := state.ToTable(-1); err == nil {
|
||||
if dataMap, ok := data.(map[string]any); ok {
|
||||
session.Clear()
|
||||
for k, v := range dataMap {
|
||||
@ -266,55 +446,59 @@ func (s *Server) updateSessionFromLua(session *sessions.Session) {
|
||||
}
|
||||
}
|
||||
}
|
||||
s.state.Pop(2)
|
||||
}
|
||||
state.Pop(1)
|
||||
|
||||
func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, body string) {
|
||||
s.state.GetGlobal("__response")
|
||||
if s.state.IsNil(-1) {
|
||||
s.state.Pop(1)
|
||||
if body != "" {
|
||||
ctx.SetBodyString(body)
|
||||
// Check for response table (function handlers) or response global (fast handlers)
|
||||
state.GetGlobal("__response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.GetGlobal("response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
if body != "" {
|
||||
ctx.SetBodyString(body)
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.state.GetField(-1, "status")
|
||||
if s.state.IsNumber(-1) {
|
||||
ctx.SetStatusCode(int(s.state.ToNumber(-1)))
|
||||
// Status
|
||||
if status := state.GetFieldNumber(-1, "status", 200); status != 200 {
|
||||
ctx.SetStatusCode(int(status))
|
||||
}
|
||||
s.state.Pop(1)
|
||||
|
||||
s.state.GetField(-1, "headers")
|
||||
if s.state.IsTable(-1) {
|
||||
s.state.ForEachTableKV(-1, func(key, value string) bool {
|
||||
// Headers
|
||||
state.GetField(-1, "headers")
|
||||
if state.IsTable(-1) {
|
||||
state.ForEachTableKV(-1, func(key, value string) bool {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
return true
|
||||
})
|
||||
}
|
||||
s.state.Pop(1)
|
||||
state.Pop(1)
|
||||
|
||||
s.state.GetField(-1, "cookies")
|
||||
if s.state.IsTable(-1) {
|
||||
s.applyCookies(ctx)
|
||||
// Cookies
|
||||
state.GetField(-1, "cookies")
|
||||
if state.IsTable(-1) {
|
||||
s.applyCookies(ctx, state)
|
||||
}
|
||||
s.state.Pop(1)
|
||||
state.Pop(1)
|
||||
|
||||
s.state.Pop(1)
|
||||
state.Pop(1)
|
||||
|
||||
if body != "" {
|
||||
ctx.SetBodyString(body)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
|
||||
s.state.ForEachArray(-1, func(i int, state *luajit.State) bool {
|
||||
if !state.IsTable(-1) {
|
||||
func (s *Server) applyCookies(ctx *fasthttp.RequestCtx, state *luajit.State) {
|
||||
state.ForEachArray(-1, func(i int, st *luajit.State) bool {
|
||||
if !st.IsTable(-1) {
|
||||
return true
|
||||
}
|
||||
|
||||
name := state.GetFieldString(-1, "name", "")
|
||||
value := state.GetFieldString(-1, "value", "")
|
||||
name := st.GetFieldString(-1, "name", "")
|
||||
value := st.GetFieldString(-1, "value", "")
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
@ -324,22 +508,29 @@ func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
|
||||
|
||||
cookie.SetKey(name)
|
||||
cookie.SetValue(value)
|
||||
cookie.SetPath(state.GetFieldString(-1, "path", "/"))
|
||||
|
||||
if domain := state.GetFieldString(-1, "domain", ""); domain != "" {
|
||||
cookie.SetDomain(domain)
|
||||
}
|
||||
|
||||
if state.GetFieldBool(-1, "secure", false) {
|
||||
cookie.SetSecure(true)
|
||||
}
|
||||
|
||||
if state.GetFieldBool(-1, "http_only", true) {
|
||||
cookie.SetHTTPOnly(true)
|
||||
}
|
||||
|
||||
if maxAge := state.GetFieldNumber(-1, "max_age", 0); maxAge > 0 {
|
||||
cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second))
|
||||
if options, ok := st.GetFieldTable(-1, "options"); ok {
|
||||
if optMap, ok := options.(map[string]any); ok {
|
||||
if path, ok := optMap["path"].(string); ok {
|
||||
cookie.SetPath(path)
|
||||
} else {
|
||||
cookie.SetPath("/")
|
||||
}
|
||||
if domain, ok := optMap["domain"].(string); ok {
|
||||
cookie.SetDomain(domain)
|
||||
}
|
||||
if secure, ok := optMap["secure"].(bool); ok && secure {
|
||||
cookie.SetSecure(true)
|
||||
}
|
||||
if httpOnly, ok := optMap["http_only"].(bool); ok {
|
||||
cookie.SetHTTPOnly(httpOnly)
|
||||
} else {
|
||||
cookie.SetHTTPOnly(true)
|
||||
}
|
||||
if maxAge, ok := optMap["max_age"].(int); ok && maxAge > 0 {
|
||||
cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
@ -347,42 +538,6 @@ func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) buildRequestContext(ctx *fasthttp.RequestCtx, params *router.Params) *RequestContext {
|
||||
reqCtx := &RequestContext{
|
||||
Method: string(ctx.Method()),
|
||||
Path: string(ctx.Path()),
|
||||
Headers: make(map[string]string),
|
||||
Query: make(map[string]string),
|
||||
Cookies: make(map[string]string),
|
||||
Body: string(ctx.PostBody()),
|
||||
Params: make(map[string]string),
|
||||
}
|
||||
|
||||
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
reqCtx.Headers[string(key)] = string(value)
|
||||
})
|
||||
|
||||
ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||
reqCtx.Cookies[string(key)] = string(value)
|
||||
})
|
||||
|
||||
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
reqCtx.Query[string(key)] = string(value)
|
||||
})
|
||||
|
||||
// Convert router params to map
|
||||
for i, key := range params.Keys {
|
||||
if i < len(params.Values) {
|
||||
reqCtx.Params[key] = params.Values[i]
|
||||
}
|
||||
}
|
||||
|
||||
reqCtx.Form = s.parseForm(ctx)
|
||||
reqCtx.Session = s.sessions.GetSessionFromRequest(ctx)
|
||||
|
||||
return reqCtx
|
||||
}
|
||||
|
||||
func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any {
|
||||
contentType := string(ctx.Request.Header.ContentType())
|
||||
form := make(map[string]any)
|
||||
@ -440,317 +595,3 @@ func (s *Server) generateCSRFToken() string {
|
||||
rand.Read(bytes)
|
||||
return base64.URLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// Lua function implementations
|
||||
func httpSetStatus(state *luajit.State) int {
|
||||
code, _ := state.SafeToNumber(1)
|
||||
state.GetGlobal("__response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__response")
|
||||
state.GetGlobal("__response")
|
||||
}
|
||||
state.PushNumber(code)
|
||||
state.SetField(-2, "status")
|
||||
state.Pop(1)
|
||||
return 0
|
||||
}
|
||||
|
||||
func httpSetHeader(state *luajit.State) int {
|
||||
name, _ := state.SafeToString(1)
|
||||
value, _ := state.SafeToString(2)
|
||||
|
||||
state.GetGlobal("__response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__response")
|
||||
state.GetGlobal("__response")
|
||||
}
|
||||
|
||||
state.GetField(-1, "headers")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "headers")
|
||||
}
|
||||
|
||||
state.PushString(value)
|
||||
state.SetField(-2, name)
|
||||
state.Pop(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
func httpRedirect(state *luajit.State) int {
|
||||
url, _ := state.SafeToString(1)
|
||||
status := 302.0
|
||||
if state.GetTop() >= 2 {
|
||||
status, _ = state.SafeToNumber(2)
|
||||
}
|
||||
|
||||
state.GetGlobal("__response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__response")
|
||||
state.GetGlobal("__response")
|
||||
}
|
||||
|
||||
state.PushNumber(status)
|
||||
state.SetField(-2, "status")
|
||||
|
||||
state.GetField(-1, "headers")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "headers")
|
||||
}
|
||||
|
||||
state.PushString(url)
|
||||
state.SetField(-2, "Location")
|
||||
state.Pop(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *Server) sessionGet(state *luajit.State) int {
|
||||
key, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, "data")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(2)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, key)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) sessionSet(state *luajit.State) int {
|
||||
key, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__session")
|
||||
state.GetGlobal("__session")
|
||||
}
|
||||
|
||||
state.GetField(-1, "data")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "data")
|
||||
}
|
||||
|
||||
value, err := state.ToValue(2)
|
||||
if err == nil {
|
||||
state.PushValue(value)
|
||||
state.SetField(-2, key)
|
||||
}
|
||||
state.Pop(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *Server) sessionFlash(state *luajit.State) int {
|
||||
key, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__session")
|
||||
state.GetGlobal("__session")
|
||||
}
|
||||
|
||||
state.GetField(-1, "flash")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "flash")
|
||||
}
|
||||
|
||||
value, err := state.ToValue(2)
|
||||
if err == nil {
|
||||
state.PushValue(value)
|
||||
state.SetField(-2, key)
|
||||
}
|
||||
state.Pop(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *Server) sessionGetFlash(state *luajit.State) int {
|
||||
key, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, "flash")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(2)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, key)
|
||||
return 1
|
||||
}
|
||||
|
||||
func cookieSet(state *luajit.State) int {
|
||||
name, _ := state.SafeToString(1)
|
||||
value, _ := state.SafeToString(2)
|
||||
|
||||
maxAge := 0
|
||||
path := "/"
|
||||
domain := ""
|
||||
secure := false
|
||||
httpOnly := true
|
||||
|
||||
if state.GetTop() >= 3 && state.IsTable(3) {
|
||||
maxAge = int(state.GetFieldNumber(3, "max_age", 0))
|
||||
path = state.GetFieldString(3, "path", "/")
|
||||
domain = state.GetFieldString(3, "domain", "")
|
||||
secure = state.GetFieldBool(3, "secure", false)
|
||||
httpOnly = state.GetFieldBool(3, "http_only", true)
|
||||
}
|
||||
|
||||
state.GetGlobal("__response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__response")
|
||||
state.GetGlobal("__response")
|
||||
}
|
||||
|
||||
state.GetField(-1, "cookies")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "cookies")
|
||||
}
|
||||
|
||||
cookieData := map[string]any{
|
||||
"name": name,
|
||||
"value": value,
|
||||
"path": path,
|
||||
"secure": secure,
|
||||
"http_only": httpOnly,
|
||||
}
|
||||
if domain != "" {
|
||||
cookieData["domain"] = domain
|
||||
}
|
||||
if maxAge > 0 {
|
||||
cookieData["max_age"] = maxAge
|
||||
}
|
||||
|
||||
state.PushValue(cookieData)
|
||||
|
||||
length := globalServer.getTableLength(-2)
|
||||
state.PushNumber(float64(length + 1))
|
||||
state.PushCopy(-2)
|
||||
state.SetTable(-4)
|
||||
|
||||
state.Pop(3)
|
||||
return 0
|
||||
}
|
||||
|
||||
func cookieGet(state *luajit.State) int {
|
||||
name, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__request")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, "cookies")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(2)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, name)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) csrfGenerate(state *luajit.State) int {
|
||||
token := s.generateCSRFToken()
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__session")
|
||||
state.GetGlobal("__session")
|
||||
}
|
||||
|
||||
state.GetField(-1, "data")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "data")
|
||||
}
|
||||
|
||||
state.PushString(token)
|
||||
state.SetField(-2, "_csrf_token")
|
||||
state.Pop(2)
|
||||
|
||||
state.PushString(token)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) csrfValidate(state *luajit.State) int {
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
sessionToken := state.GetFieldString(-1, "data._csrf_token", "")
|
||||
state.Pop(1)
|
||||
|
||||
state.GetGlobal("__request")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
requestToken := state.GetFieldString(-1, "form._csrf_token", "")
|
||||
state.Pop(1)
|
||||
|
||||
state.PushBoolean(sessionToken != "" && sessionToken == requestToken)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) getTableLength(index int) int {
|
||||
length := 0
|
||||
s.state.PushNil()
|
||||
for s.state.Next(index - 1) {
|
||||
length++
|
||||
s.state.Pop(1)
|
||||
}
|
||||
return length
|
||||
}
|
||||
|
144
http/http.lua
144
http/http.lua
@ -1,3 +1,7 @@
|
||||
-- Fast response handling
|
||||
local response = {status = 200, headers = {}, cookies = {}}
|
||||
local session_data = {}
|
||||
|
||||
http = {}
|
||||
|
||||
function http.listen(port)
|
||||
@ -8,20 +12,15 @@ function http.route(method, path, handler)
|
||||
return __http_route(method, path, handler)
|
||||
end
|
||||
|
||||
function http.status(code)
|
||||
return __http_set_status(code)
|
||||
function http.status(code)
|
||||
response.status = code
|
||||
end
|
||||
|
||||
function http.header(name, value)
|
||||
return __http_set_header(name, value)
|
||||
function http.header(k, v)
|
||||
response.headers[k] = v
|
||||
end
|
||||
|
||||
function http.redirect(url, status)
|
||||
__http_redirect(url, status or 302)
|
||||
coroutine.yield() -- Exit handler
|
||||
end
|
||||
|
||||
function http.json(data)
|
||||
function http.json(data)
|
||||
http.header("Content-Type", "application/json")
|
||||
return json.encode(data)
|
||||
end
|
||||
@ -32,84 +31,119 @@ function http.html(content)
|
||||
end
|
||||
|
||||
function http.text(content)
|
||||
http.header("Content-Type", "text/plain")
|
||||
http.header("Content-Type", "text/plain")
|
||||
return content
|
||||
end
|
||||
|
||||
function http.redirect(url, code)
|
||||
response.status = code or 302
|
||||
response.headers["Location"] = url
|
||||
coroutine.yield()
|
||||
end
|
||||
|
||||
-- Session functions
|
||||
session = {}
|
||||
|
||||
function session.get(key)
|
||||
return __session_get(key)
|
||||
function session.get(key)
|
||||
return session_data[key]
|
||||
end
|
||||
|
||||
function session.set(key, value)
|
||||
return __session_set(key, value)
|
||||
function session.set(key, val)
|
||||
session_data[key] = val
|
||||
end
|
||||
|
||||
function session.flash(key, value)
|
||||
return __session_flash(key, value)
|
||||
function session.flash(key, val)
|
||||
session_data["_flash_" .. key] = val
|
||||
end
|
||||
|
||||
function session.get_flash(key)
|
||||
return __session_get_flash(key)
|
||||
function session.get_flash(key)
|
||||
local val = session_data["_flash_" .. key]
|
||||
session_data["_flash_" .. key] = nil
|
||||
return val
|
||||
end
|
||||
|
||||
-- Cookie functions
|
||||
-- Cookie functions
|
||||
cookie = {}
|
||||
|
||||
function cookie.set(name, value, options)
|
||||
return __cookie_set(name, value, options)
|
||||
function cookie.get(name)
|
||||
return COOKIES and COOKIES[name]
|
||||
end
|
||||
|
||||
function cookie.get(name)
|
||||
return __cookie_get(name)
|
||||
function cookie.set(name, value, options)
|
||||
response.cookies[#response.cookies + 1] = {
|
||||
name = name,
|
||||
value = value,
|
||||
options = options or {}
|
||||
}
|
||||
end
|
||||
|
||||
-- CSRF functions
|
||||
csrf = {}
|
||||
|
||||
function csrf.generate()
|
||||
return __csrf_generate()
|
||||
local token = CSRF_TOKEN or ""
|
||||
session.set("_csrf_token", token)
|
||||
return token
|
||||
end
|
||||
|
||||
function csrf.validate()
|
||||
return __csrf_validate()
|
||||
local session_token = session.get("_csrf_token")
|
||||
local form_token = FORM and FORM._csrf_token
|
||||
return session_token and session_token == form_token
|
||||
end
|
||||
|
||||
function csrf.field()
|
||||
local token = csrf.generate()
|
||||
return string.format('<input type="hidden" name="_csrf_token" value="%s" />', token)
|
||||
return '<input type="hidden" name="_csrf_token" value="' .. csrf.generate() .. '" />'
|
||||
end
|
||||
|
||||
-- Fast JSON encoding
|
||||
json = {
|
||||
encode = function(data)
|
||||
if type(data) == "string" then
|
||||
return '"' .. data .. '"'
|
||||
elseif type(data) == "number" then
|
||||
return tostring(data)
|
||||
elseif type(data) == "boolean" then
|
||||
return data and "true" or "false"
|
||||
elseif data == nil then
|
||||
return "null"
|
||||
elseif type(data) == "table" then
|
||||
-- Check if it's an array
|
||||
local isArray = true
|
||||
local n = 0
|
||||
for k, v in pairs(data) do
|
||||
n = n + 1
|
||||
if type(k) ~= "number" or k ~= n then
|
||||
isArray = false
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if isArray then
|
||||
local result = "["
|
||||
for i = 1, n do
|
||||
if i > 1 then result = result .. "," end
|
||||
result = result .. json.encode(data[i])
|
||||
end
|
||||
return result .. "]"
|
||||
else
|
||||
local result = "{"
|
||||
local first = true
|
||||
for k, v in pairs(data) do
|
||||
if not first then result = result .. "," end
|
||||
result = result .. '"' .. tostring(k) .. '":' .. json.encode(v)
|
||||
first = false
|
||||
end
|
||||
return result .. "}"
|
||||
end
|
||||
else
|
||||
return json_encode_fallback(data)
|
||||
end
|
||||
end
|
||||
}
|
||||
|
||||
-- Helper functions
|
||||
function redirect_with_flash(url, type, message)
|
||||
session.flash(type, message)
|
||||
http.redirect(url)
|
||||
end
|
||||
|
||||
-- JSON encoding/decoding placeholder
|
||||
json = {
|
||||
encode = function(data)
|
||||
-- Simplified JSON encoding
|
||||
if type(data) == "table" then
|
||||
local result = "{"
|
||||
local first = true
|
||||
for k, v in pairs(data) do
|
||||
if not first then result = result .. "," end
|
||||
result = result .. '"' .. tostring(k) .. '":' .. json.encode(v)
|
||||
first = false
|
||||
end
|
||||
return result .. "}"
|
||||
elseif type(data) == "string" then
|
||||
return '"' .. data .. '"'
|
||||
else
|
||||
return tostring(data)
|
||||
end
|
||||
end,
|
||||
|
||||
decode = function(str)
|
||||
-- Simplified JSON decoding - you'd want a proper implementation
|
||||
return {}
|
||||
end
|
||||
}
|
||||
end
|
@ -1,25 +1,24 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// node represents a node in the radix trie
|
||||
// Handler function that takes parameters as strings
|
||||
type Handler func(params []string)
|
||||
|
||||
type node struct {
|
||||
segment string
|
||||
handler int // Lua function reference
|
||||
handlerID int
|
||||
children []*node
|
||||
isDynamic bool // :param
|
||||
isWildcard bool // *param
|
||||
paramName string
|
||||
isDynamic bool
|
||||
isWildcard bool
|
||||
maxParams uint8
|
||||
}
|
||||
|
||||
// Router is a string-based HTTP router with efficient lookup
|
||||
type Router struct {
|
||||
get, post, put, patch, delete *node
|
||||
mu sync.RWMutex
|
||||
paramsBuffer []string
|
||||
}
|
||||
|
||||
// Params holds URL parameters
|
||||
@ -41,15 +40,15 @@ func (p *Params) Get(name string) string {
|
||||
// New creates a new Router instance
|
||||
func New() *Router {
|
||||
return &Router{
|
||||
get: &node{},
|
||||
post: &node{},
|
||||
put: &node{},
|
||||
patch: &node{},
|
||||
delete: &node{},
|
||||
get: &node{},
|
||||
post: &node{},
|
||||
put: &node{},
|
||||
patch: &node{},
|
||||
delete: &node{},
|
||||
paramsBuffer: make([]string, 64),
|
||||
}
|
||||
}
|
||||
|
||||
// methodNode returns the root node for a method
|
||||
func (r *Router) methodNode(method string) *node {
|
||||
switch method {
|
||||
case "GET":
|
||||
@ -67,47 +66,71 @@ func (r *Router) methodNode(method string) *node {
|
||||
}
|
||||
}
|
||||
|
||||
// AddRoute adds a new route with handler reference
|
||||
func (r *Router) AddRoute(method, path string, handlerRef int) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// AddRoute adds a route with handler ID (for compatibility)
|
||||
func (r *Router) AddRoute(method, path string, handlerID int) error {
|
||||
root := r.methodNode(method)
|
||||
if root == nil {
|
||||
return errors.New("unsupported HTTP method")
|
||||
return fmt.Errorf("unsupported method: %s", method)
|
||||
}
|
||||
|
||||
// Create a handler that stores the ID
|
||||
h := func(params []string) {
|
||||
// This is a placeholder - the actual execution happens in HTTP module
|
||||
}
|
||||
|
||||
return r.addRoute(root, path, h, handlerID)
|
||||
}
|
||||
|
||||
// readSegment extracts the next path segment
|
||||
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
|
||||
if start >= len(path) {
|
||||
return "", start, false
|
||||
}
|
||||
if path[start] == '/' {
|
||||
start++
|
||||
}
|
||||
if start >= len(path) {
|
||||
return "", start, false
|
||||
}
|
||||
end = start
|
||||
for end < len(path) && path[end] != '/' {
|
||||
end++
|
||||
}
|
||||
return path[start:end], end, end < len(path)
|
||||
}
|
||||
|
||||
// addRoute adds a new route to the trie
|
||||
func (r *Router) addRoute(root *node, path string, h Handler, handlerID int) error {
|
||||
if path == "/" {
|
||||
root.handler = handlerRef
|
||||
root.handlerID = handlerID
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.addRoute(root, path, handlerRef)
|
||||
}
|
||||
|
||||
// addRoute adds a route to the trie
|
||||
func (r *Router) addRoute(root *node, path string, handlerRef int) error {
|
||||
segments := r.parseSegments(path)
|
||||
current := root
|
||||
pos := 0
|
||||
lastWC := false
|
||||
count := uint8(0)
|
||||
|
||||
for _, seg := range segments {
|
||||
isDyn := strings.HasPrefix(seg, ":")
|
||||
isWC := strings.HasPrefix(seg, "*")
|
||||
|
||||
if isWC && seg != segments[len(segments)-1] {
|
||||
return errors.New("wildcard must be the last segment")
|
||||
for {
|
||||
seg, newPos, more := readSegment(path, pos)
|
||||
if seg == "" {
|
||||
break
|
||||
}
|
||||
|
||||
paramName := ""
|
||||
if isDyn {
|
||||
paramName = seg[1:]
|
||||
seg = ":"
|
||||
} else if isWC {
|
||||
paramName = seg[1:]
|
||||
seg = "*"
|
||||
isDyn := len(seg) > 1 && seg[0] == ':'
|
||||
isWC := len(seg) > 0 && seg[0] == '*'
|
||||
|
||||
if isWC {
|
||||
if lastWC || more {
|
||||
return fmt.Errorf("wildcard must be the last segment in the path")
|
||||
}
|
||||
lastWC = true
|
||||
}
|
||||
|
||||
if isDyn || isWC {
|
||||
count++
|
||||
}
|
||||
|
||||
// Find or create child
|
||||
var child *node
|
||||
for _, c := range current.children {
|
||||
if c.segment == seg {
|
||||
@ -117,144 +140,102 @@ func (r *Router) addRoute(root *node, path string, handlerRef int) error {
|
||||
}
|
||||
|
||||
if child == nil {
|
||||
child = &node{
|
||||
segment: seg,
|
||||
isDynamic: isDyn,
|
||||
isWildcard: isWC,
|
||||
paramName: paramName,
|
||||
}
|
||||
child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC}
|
||||
current.children = append(current.children, child)
|
||||
}
|
||||
|
||||
if child.maxParams < count {
|
||||
child.maxParams = count
|
||||
}
|
||||
|
||||
current = child
|
||||
pos = newPos
|
||||
}
|
||||
|
||||
current.handler = handlerRef
|
||||
current.handlerID = handlerID
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseSegments splits path into segments
|
||||
func (r *Router) parseSegments(path string) []string {
|
||||
segments := strings.Split(strings.Trim(path, "/"), "/")
|
||||
var result []string
|
||||
for _, seg := range segments {
|
||||
if seg != "" {
|
||||
result = append(result, seg)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Lookup finds handler and parameters for a method and path
|
||||
// Lookup finds a handler matching method and path
|
||||
func (r *Router) Lookup(method, path string) (int, *Params, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
root := r.methodNode(method)
|
||||
if root == nil {
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
if path == "/" {
|
||||
if root.handler != 0 {
|
||||
return root.handler, &Params{}, true
|
||||
if root.handlerID != 0 {
|
||||
return root.handlerID, &Params{}, true
|
||||
}
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
segments := r.parseSegments(path)
|
||||
handler, params := r.match(root, segments, 0)
|
||||
if handler == 0 {
|
||||
buffer := r.paramsBuffer
|
||||
if cap(buffer) < int(root.maxParams) {
|
||||
buffer = make([]string, root.maxParams)
|
||||
r.paramsBuffer = buffer
|
||||
}
|
||||
buffer = buffer[:0]
|
||||
|
||||
handlerID, paramCount, paramKeys, found := r.match(root, path, 0, &buffer)
|
||||
if !found {
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
return handler, params, true
|
||||
}
|
||||
|
||||
// match traverses the trie to find handler
|
||||
func (r *Router) match(current *node, segments []string, index int) (int, *Params) {
|
||||
if index >= len(segments) {
|
||||
if current.handler != 0 {
|
||||
return current.handler, &Params{}
|
||||
}
|
||||
return 0, nil
|
||||
params := &Params{
|
||||
Keys: paramKeys,
|
||||
Values: buffer[:paramCount],
|
||||
}
|
||||
|
||||
segment := segments[index]
|
||||
return handlerID, params, true
|
||||
}
|
||||
|
||||
// Check exact match first
|
||||
for _, child := range current.children {
|
||||
if child.segment == segment {
|
||||
handler, params := r.match(child, segments, index+1)
|
||||
if handler != 0 {
|
||||
return handler, params
|
||||
// match traverses the trie to find a handler
|
||||
func (r *Router) match(current *node, path string, start int, params *[]string) (int, int, []string, bool) {
|
||||
paramCount := 0
|
||||
var paramKeys []string
|
||||
|
||||
// Check wildcards first
|
||||
for _, c := range current.children {
|
||||
if c.isWildcard {
|
||||
rem := path[start:]
|
||||
if len(rem) > 0 && rem[0] == '/' {
|
||||
rem = rem[1:]
|
||||
}
|
||||
*params = append(*params, rem)
|
||||
// Extract param name from *name format
|
||||
paramName := c.segment[1:] // Remove the * prefix
|
||||
paramKeys = append(paramKeys, paramName)
|
||||
return c.handlerID, 1, paramKeys, c.handlerID != 0
|
||||
}
|
||||
}
|
||||
|
||||
seg, pos, more := readSegment(path, start)
|
||||
if seg == "" {
|
||||
return current.handlerID, 0, paramKeys, current.handlerID != 0
|
||||
}
|
||||
|
||||
for _, c := range current.children {
|
||||
if c.segment == seg || c.isDynamic {
|
||||
if c.isDynamic {
|
||||
*params = append(*params, seg)
|
||||
// Extract param name from :name format
|
||||
paramName := c.segment[1:] // Remove the : prefix
|
||||
paramKeys = append(paramKeys, paramName)
|
||||
paramCount++
|
||||
}
|
||||
|
||||
if !more {
|
||||
return c.handlerID, paramCount, paramKeys, c.handlerID != 0
|
||||
}
|
||||
|
||||
handlerID, nestedCount, nestedKeys, ok := r.match(c, path, pos, params)
|
||||
if ok {
|
||||
allKeys := append(paramKeys, nestedKeys...)
|
||||
return handlerID, paramCount + nestedCount, allKeys, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check dynamic match second
|
||||
for _, child := range current.children {
|
||||
if child.isDynamic {
|
||||
handler, params := r.match(child, segments, index+1)
|
||||
if handler != 0 {
|
||||
// Prepend this parameter
|
||||
newParams := &Params{
|
||||
Keys: append([]string{child.paramName}, params.Keys...),
|
||||
Values: append([]string{segment}, params.Values...),
|
||||
}
|
||||
return handler, newParams
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check wildcard last (catches everything remaining)
|
||||
for _, child := range current.children {
|
||||
if child.isWildcard {
|
||||
remaining := strings.Join(segments[index:], "/")
|
||||
return child.handler, &Params{
|
||||
Keys: []string{child.paramName},
|
||||
Values: []string{remaining},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// RemoveRoute removes a route
|
||||
func (r *Router) RemoveRoute(method, path string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
root := r.methodNode(method)
|
||||
if root == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if path == "/" {
|
||||
root.handler = 0
|
||||
return
|
||||
}
|
||||
|
||||
segments := r.parseSegments(path)
|
||||
r.removeRoute(root, segments, 0)
|
||||
}
|
||||
|
||||
// removeRoute removes a route from the trie
|
||||
func (r *Router) removeRoute(current *node, segments []string, index int) {
|
||||
if index >= len(segments) {
|
||||
current.handler = 0
|
||||
return
|
||||
}
|
||||
|
||||
segment := segments[index]
|
||||
|
||||
for _, child := range current.children {
|
||||
if child.segment == segment ||
|
||||
(child.isDynamic && strings.HasPrefix(segment, ":")) ||
|
||||
(child.isWildcard && strings.HasPrefix(segment, "*")) {
|
||||
r.removeRoute(child, segments, index+1)
|
||||
break
|
||||
}
|
||||
}
|
||||
return 0, 0, nil, false
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user