496 lines
11 KiB
Go

package http
import (
"fmt"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/valyala/fasthttp"
)
// Handler serves one routed request. params holds the values captured by the
// route's dynamic (":name") and wildcard ("*name") segments, in the order those
// segments appear in the route's path; it is nil for the "/" route.
type Handler func(ctx *fasthttp.RequestCtx, params []string)
// node is a single path segment in one method's routing tree.
type node struct {
	segment  string  // literal segment text, including a leading ':' or '*'
	handler  Handler // handler of the route that ends here; nil if none does
	children []*node // child segments, in registration order

	// isDynamic marks a ':' segment, isWildcard a '*' segment.
	isDynamic, isWildcard bool

	// maxParams is the parameter count addRoute recorded for routes
	// passing through this node.
	maxParams uint8
}
// Router keeps one routing tree per supported HTTP method.
type Router struct {
	get    *node
	post   *node
	put    *node
	patch  *node
	delete *node

	// paramsBuffer is scratch storage for captured route parameters. It is a
	// single slice shared by everything that uses this Router.
	paramsBuffer []string
}
// newRouter returns a Router with an empty root node for every supported
// method and a 64-entry parameter scratch buffer.
func newRouter() *Router {
	r := new(Router)
	r.get, r.post, r.put, r.patch, r.delete = &node{}, &node{}, &node{}, &node{}, &node{}
	r.paramsBuffer = make([]string, 64)
	return r
}
// HTTPServer ties a fasthttp server to its Router and to the bookkeeping the
// Lua bindings keep for it. Requests that reach Lua are serialized by luaMu.
type HTTPServer struct {
	server *fasthttp.Server // underlying fasthttp server
	router *Router          // routes registered from Lua

	addr    string // address passed to the last http_server_listen
	running bool   // whether the server is considered to be serving

	mu    sync.RWMutex // guards addr and running
	luaMu sync.Mutex   // serializes this server's calls into Lua
}
// serverRegistry maps the integer IDs handed out to Lua onto live servers.
// IDs are assigned sequentially starting at 1 and are never reused.
var serverRegistry = struct {
	sync.RWMutex
	servers map[int]*HTTPServer
	nextID  int
}{
	nextID:  1,
	servers: make(map[int]*HTTPServer),
}
// GetHTTPFunctions returns the Lua-facing HTTP server API, keyed by the name
// each Go function is exported under.
//
// Lua code refers to a server by the integer ID returned from
// http_create_server; the remaining functions take that ID as their first
// argument. Validation and lookup failures are reported with PushError.
func GetHTTPFunctions() map[string]luajit.GoFunction {
	return map[string]luajit.GoFunction{
		// http_create_server() -> id
		"http_create_server": func(s *luajit.State) int {
			server := &HTTPServer{
				server: &fasthttp.Server{
					ReadTimeout:  10 * time.Second,
					WriteTimeout: 10 * time.Second,
					IdleTimeout:  60 * time.Second,
				},
				router: newRouter(),
			}
			server.server.Handler = server.handleRequest

			serverRegistry.Lock()
			id := serverRegistry.nextID
			serverRegistry.nextID++
			serverRegistry.servers[id] = server
			serverRegistry.Unlock()

			s.PushNumber(float64(id))
			return 1
		},

		// http_server_listen(id, addr) -> true once serving has been started.
		"http_server_listen": func(s *luajit.State) int {
			if err := s.CheckExactArgs(2); err != nil {
				return s.PushError("http_server_listen: %v", err)
			}
			id, ok := serverIDArg(s)
			if !ok {
				return s.PushError("http_server_listen: server ID must be an integer")
			}
			addr, err := s.SafeToString(2)
			if err != nil {
				return s.PushError("http_server_listen: address must be a string")
			}
			server, exists := registeredServer(id)
			if !exists {
				return s.PushError("http_server_listen: server not found")
			}

			server.mu.Lock()
			if server.running {
				server.mu.Unlock()
				return s.PushError("http_server_listen: server already running")
			}
			server.addr = addr
			server.running = true
			server.mu.Unlock()

			// ListenAndServe blocks for the server's lifetime, so it runs in the
			// background. NOTE(review): a failure here (for example, the address
			// is already in use) only resets running; the Lua caller has already
			// received true and never sees the error.
			go func() {
				if err := server.server.ListenAndServe(addr); err != nil {
					server.mu.Lock()
					server.running = false
					server.mu.Unlock()
				}
			}()

			s.PushBoolean(true)
			return 1
		},

		// http_server_stop(id) -> true if the server was running, else false.
		"http_server_stop": func(s *luajit.State) int {
			if err := s.CheckMinArgs(1); err != nil {
				return s.PushError("http_server_stop: %v", err)
			}
			id, ok := serverIDArg(s)
			if !ok {
				return s.PushError("http_server_stop: server ID must be an integer")
			}
			server, exists := registeredServer(id)
			if !exists {
				return s.PushError("http_server_stop: server not found")
			}

			server.mu.Lock()
			if !server.running {
				server.mu.Unlock()
				s.PushBoolean(false)
				return 1
			}
			server.running = false
			server.mu.Unlock()

			// Shutdown waits for open connections, so it runs without holding mu.
			if err := server.server.Shutdown(); err != nil {
				return s.PushError("http_server_stop: %v", err)
			}
			s.PushBoolean(true)
			return 1
		},

		// http_server_<method>(id, path, handler) -> true
		"http_server_get":    createRouteHandler("GET"),
		"http_server_post":   createRouteHandler("POST"),
		"http_server_put":    createRouteHandler("PUT"),
		"http_server_patch":  createRouteHandler("PATCH"),
		"http_server_delete": createRouteHandler("DELETE"),

		// http_server_is_running(id) -> bool; an unknown ID reports false.
		"http_server_is_running": func(s *luajit.State) int {
			if err := s.CheckMinArgs(1); err != nil {
				return s.PushError("http_server_is_running: %v", err)
			}
			id, ok := serverIDArg(s)
			if !ok {
				return s.PushError("http_server_is_running: server ID must be an integer")
			}
			server, exists := registeredServer(id)
			if !exists {
				s.PushBoolean(false)
				return 1
			}

			server.mu.RLock()
			running := server.running
			server.mu.RUnlock()

			s.PushBoolean(running)
			return 1
		},

		// http_cleanup_servers() -> true. Stops and forgets every server.
		"http_cleanup_servers": func(s *luajit.State) int {
			// Detach all servers first so the registry lock is not held while
			// Shutdown blocks waiting for in-flight requests to finish.
			serverRegistry.Lock()
			servers := make([]*HTTPServer, 0, len(serverRegistry.servers))
			for id, server := range serverRegistry.servers {
				servers = append(servers, server)
				delete(serverRegistry.servers, id)
			}
			serverRegistry.Unlock()

			for _, server := range servers {
				server.mu.Lock()
				if server.running {
					// Best effort: cleanup has no channel to report per-server
					// failures, and the server is discarded either way.
					_ = server.server.Shutdown()
					server.running = false
				}
				server.mu.Unlock()
			}

			s.PushBoolean(true)
			return 1
		},
	}
}

// serverIDArg reads Lua argument 1 as a server ID, reporting false when it is
// not a number or not integral.
func serverIDArg(s *luajit.State) (int, bool) {
	v, err := s.SafeToNumber(1)
	if err != nil || v != float64(int(v)) {
		return 0, false
	}
	return int(v), true
}

// registeredServer returns the server registered under id, if any.
func registeredServer(id int) (*HTTPServer, bool) {
	serverRegistry.RLock()
	defer serverRegistry.RUnlock()
	server, ok := serverRegistry.servers[id]
	return server, ok
}
// createRouteHandler builds the Lua binding that registers routes for method
// ("GET", "POST", ...). From Lua it is called as fn(serverID, path, handler)
// and returns true once the route is added. Argument, lookup and routing
// failures are reported with PushError.
//
// The Lua handler is kept with StoreLuaFunction and invoked through
// HTTPServer.callLuaHandler for every request that matches the route.
func createRouteHandler(method string) luajit.GoFunction {
	// Prefix errors with the name the binding is exported under (e.g.
	// "http_server_get" for "GET"), not with the upper-case method name.
	name := []byte("http_server_" + method)
	for i, c := range name {
		if 'A' <= c && c <= 'Z' {
			name[i] = c + ('a' - 'A')
		}
	}
	fn := string(name)

	return func(s *luajit.State) int {
		if err := s.CheckExactArgs(3); err != nil {
			return s.PushError("%s: %v", fn, err)
		}
		serverID, err := s.SafeToNumber(1)
		if err != nil || serverID != float64(int(serverID)) {
			return s.PushError("%s: server ID must be an integer", fn)
		}
		path, err := s.SafeToString(2)
		if err != nil {
			return s.PushError("%s: path must be a string", fn)
		}
		if !s.IsFunction(3) {
			return s.PushError("%s: handler must be a function", fn)
		}

		serverRegistry.RLock()
		server, exists := serverRegistry.servers[int(serverID)]
		serverRegistry.RUnlock()
		if !exists {
			return s.PushError("%s: server not found", fn)
		}

		// NOTE(review): if addRoute fails below, luaFunc stays stored and nothing
		// here releases it — confirm whether LuaFunction offers a way to free it.
		luaFunc, err := s.StoreLuaFunction(3)
		if err != nil {
			return s.PushError("%s: failed to store function: %v", fn, err)
		}
		handler := func(ctx *fasthttp.RequestCtx, params []string) {
			server.callLuaHandler(ctx, params, luaFunc)
		}
		if err := server.router.addRoute(method, path, handler); err != nil {
			return s.PushError("%s: failed to add route: %v", fn, err)
		}

		s.PushBoolean(true)
		return 1
	}
}
// methodNode returns the root of the routing tree for method, or nil when the
// method is not one of GET, POST, PUT, PATCH or DELETE. Matching is
// case-sensitive.
func (r *Router) methodNode(method string) (root *node) {
	switch method {
	case "GET":
		root = r.get
	case "POST":
		root = r.post
	case "PUT":
		root = r.put
	case "PATCH":
		root = r.patch
	case "DELETE":
		root = r.delete
	}
	return root
}
// addRoute registers h for method and path.
//
// The path is split on '/'. A segment beginning with ':' matches any single
// segment. A segment beginning with '*' matches the remainder of the path and
// must be the final segment. Empty segments produced by doubled or trailing
// slashes are skipped, so "/a//b" registers the same route as "/a/b". Earlier
// versions stopped at the first empty segment and silently attached h to
// "/a". Registering an existing path replaces its handler.
//
// It returns an error for an unsupported method or a misplaced wildcard. It
// also records, on the root and on each node along the path, how many
// parameter segments the route has up to that point, for use as a sizing hint.
//
// NOTE(review): the tree is not locked, so adding routes while the server is
// already serving requests races with lookup.
func (r *Router) addRoute(method, path string, h Handler) error {
	root := r.methodNode(method)
	if root == nil {
		return fmt.Errorf("unsupported method: %s", method)
	}
	if path == "/" {
		root.handler = h
		return nil
	}

	current := root
	var params uint8
	for pos := 0; ; {
		seg, next, more := readSegment(path, pos)
		pos = next
		if seg == "" {
			if !more {
				break
			}
			continue // empty segment from "//": skip it rather than ending the route
		}

		isDyn := seg[0] == ':'
		isWC := seg[0] == '*'
		if isWC && more {
			return fmt.Errorf("wildcard must be the last segment in the path")
		}
		if isDyn || isWC {
			params++
		}

		var child *node
		for _, c := range current.children {
			if c.segment == seg {
				child = c
				break
			}
		}
		if child == nil {
			child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC}
			current.children = append(current.children, child)
		}
		if child.maxParams < params {
			child.maxParams = params
		}
		current = child
	}

	// lookup sizes its parameter slice from the root.
	if root.maxParams < params {
		root.maxParams = params
	}
	current.handler = h
	return nil
}
// lookup finds the handler registered for method and path.
//
// It returns the handler, the values captured by the route's dynamic and
// wildcard segments (in path order), and whether any route matched.
//
// The parameter slice is allocated per call. fasthttp serves requests on many
// goroutines at once, so a buffer shared through the Router (r.paramsBuffer)
// would be overwritten by other requests while the caller still holds it.
//
// NOTE(review): the tree itself is read without locking; see addRoute.
func (r *Router) lookup(method, path string) (Handler, []string, bool) {
	root := r.methodNode(method)
	if root == nil {
		return nil, nil, false
	}
	if path == "/" {
		return root.handler, nil, root.handler != nil
	}

	// root.maxParams is only a capacity hint; append grows the slice if needed.
	params := make([]string, 0, root.maxParams)
	h, n, found := match(root, path, 0, &params)
	if !found {
		return nil, nil, false
	}
	return h, params[:n], true
}
// handleRequest is the fasthttp entry point for every request. It dispatches
// to the route registered for the request's method and path, or replies
// 404 "Not Found" when no route matches.
func (hs *HTTPServer) handleRequest(ctx *fasthttp.RequestCtx) {
	h, params, ok := hs.router.lookup(string(ctx.Method()), string(ctx.Path()))
	if ok {
		h(ctx, params)
		return
	}
	ctx.SetStatusCode(fasthttp.StatusNotFound)
	ctx.WriteString("Not Found")
}
// callLuaHandler runs the Lua route handler for one request and writes its
// result into ctx.
//
// The handler is called as handler(request, response):
//   - request has method, path, query, headers, body, remote and params fields.
//   - response is a default {status = 200, headers = {}, body = ""}.
//
// If the handler's first return value converts to a map, that map replaces the
// default response. Its status (int, int64 or float64), headers (string
// values) and body (string) are then copied into the fasthttp response. A Lua
// error yields 500 with the error text as the body.
//
// Lua calls are serialized with hs.luaMu. NOTE(review): the lock is per
// server, so two servers backed by the same Lua state can still enter it
// concurrently — confirm whether a state-wide lock is required.
func (hs *HTTPServer) callLuaHandler(ctx *fasthttp.RequestCtx, params []string, handler *luajit.LuaFunction) {
	// Building the request only touches ctx, so it happens before taking the
	// Lua lock to keep the serialized section short.
	reqHeaders := make(map[string]string)
	ctx.Request.Header.VisitAll(func(key, value []byte) {
		reqHeaders[string(key)] = string(value)
	})
	request := map[string]interface{}{
		"method":  string(ctx.Method()),
		"path":    string(ctx.Path()),
		"query":   string(ctx.QueryArgs().QueryString()),
		"headers": reqHeaders,
		"body":    string(ctx.PostBody()),
		"remote":  ctx.RemoteAddr().String(),
		"params":  params,
	}
	response := map[string]interface{}{
		"status":  200,
		"headers": make(map[string]string),
		"body":    "",
	}

	hs.luaMu.Lock()
	defer hs.luaMu.Unlock()

	results, err := handler.Call(request, response)
	if err != nil {
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		fmt.Fprintf(ctx, "Handler error: %v", err)
		return
	}
	if len(results) > 0 {
		if respMap, ok := results[0].(map[string]interface{}); ok {
			response = respMap
		}
	}

	switch status := response["status"].(type) {
	case int:
		ctx.SetStatusCode(status)
	case int64:
		ctx.SetStatusCode(int(status))
	case float64:
		ctx.SetStatusCode(int(status))
	}

	// Accept both shapes: the default response carries map[string]string,
	// while a table returned from Lua arrives as map[string]interface{}.
	switch hdrs := response["headers"].(type) {
	case map[string]interface{}:
		for k, v := range hdrs {
			if str, ok := v.(string); ok {
				ctx.Response.Header.Set(k, str)
			}
		}
	case map[string]string:
		for k, v := range hdrs {
			ctx.Response.Header.Set(k, v)
		}
	}

	if body, ok := response["body"].(string); ok {
		ctx.WriteString(body)
	}
}
// readSegment extracts the path segment starting at start. A single leading
// '/' at start is skipped first.
//
// It returns the segment text, the index just past it, and whether any bytes
// follow. The segment is empty when start is at or beyond the end of path
// (end is then unchanged or just past the skipped '/'). It is also empty when
// the skipped '/' is immediately followed by another '/'.
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
	n := len(path)
	if start < n && path[start] == '/' {
		start++
	}
	if start >= n {
		return "", start, false
	}
	for end = start; end < n && path[end] != '/'; end++ {
	}
	return path[start:end], end, end < n
}
// match walks the routing tree below current, trying to consume path[start:].
//
// Static and dynamic (":name") children are tried in registration order with
// full backtracking. If a branch fails, any parameters it captured are dropped
// before the next sibling is tried. A wildcard ("*name") child is only a
// fallback: it is used when no static or dynamic branch yields a handler. It
// captures the rest of the path without its leading '/'.
//
// Captured values are appended to *params. On success it returns the handler,
// the number of values appended by this call (including nested calls), and
// true. On failure *params is restored to its length on entry and it returns
// nil, 0, false.
func match(current *node, path string, start int, params *[]string) (Handler, int, bool) {
	base := len(*params)

	seg, next, more := readSegment(path, start)
	if seg == "" {
		// End of the path (or a trailing slash): this node's own route wins.
		if current.handler != nil {
			return current.handler, 0, true
		}
	} else {
		for _, c := range current.children {
			if c.isWildcard {
				continue // handled as a fallback below
			}
			if c.isDynamic {
				*params = append(*params, seg)
			} else if c.segment != seg {
				continue
			}
			if !more {
				if c.handler != nil {
					return c.handler, len(*params) - base, true
				}
			} else if h, _, ok := match(c, path, next, params); ok {
				return h, len(*params) - base, true
			}
			// Dead end: discard what this branch captured and try the next child.
			*params = (*params)[:base]
		}
	}

	for _, c := range current.children {
		if c.isWildcard && c.handler != nil {
			rest := path[start:]
			if len(rest) > 0 && rest[0] == '/' {
				rest = rest[1:]
			}
			*params = append(*params, rest)
			return c.handler, 1, true
		}
	}
	return nil, 0, false
}