496 lines
11 KiB
Go
496 lines
11 KiB
Go
package http
|
|
|
|
import (
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
	"github.com/valyala/fasthttp"
)
|
|
|
|
// Handler is a fasthttp request handler with parameters
|
|
type Handler func(ctx *fasthttp.RequestCtx, params []string)
|
|
|
|
type node struct {
|
|
segment string
|
|
handler Handler
|
|
children []*node
|
|
isDynamic bool
|
|
isWildcard bool
|
|
maxParams uint8
|
|
}
|
|
|
|
type Router struct {
|
|
get, post, put, patch, delete *node
|
|
paramsBuffer []string
|
|
}
|
|
|
|
func newRouter() *Router {
|
|
return &Router{
|
|
get: &node{},
|
|
post: &node{},
|
|
put: &node{},
|
|
patch: &node{},
|
|
delete: &node{},
|
|
paramsBuffer: make([]string, 64),
|
|
}
|
|
}
|
|
|
|
// HTTPServer with efficient serialized Lua handling
|
|
type HTTPServer struct {
|
|
server *fasthttp.Server
|
|
router *Router
|
|
addr string
|
|
running bool
|
|
mu sync.RWMutex
|
|
luaMu sync.Mutex // Serializes Lua calls
|
|
}
|
|
|
|
var (
|
|
serverRegistry = struct {
|
|
sync.RWMutex
|
|
servers map[int]*HTTPServer
|
|
nextID int
|
|
}{
|
|
servers: make(map[int]*HTTPServer),
|
|
nextID: 1,
|
|
}
|
|
)
|
|
|
|
func GetHTTPFunctions() map[string]luajit.GoFunction {
|
|
return map[string]luajit.GoFunction{
|
|
"http_create_server": func(s *luajit.State) int {
|
|
server := &HTTPServer{
|
|
server: &fasthttp.Server{
|
|
ReadTimeout: 10 * time.Second,
|
|
WriteTimeout: 10 * time.Second,
|
|
IdleTimeout: 60 * time.Second,
|
|
},
|
|
router: newRouter(),
|
|
}
|
|
|
|
server.server.Handler = server.handleRequest
|
|
|
|
serverRegistry.Lock()
|
|
id := serverRegistry.nextID
|
|
serverRegistry.nextID++
|
|
serverRegistry.servers[id] = server
|
|
serverRegistry.Unlock()
|
|
|
|
s.PushNumber(float64(id))
|
|
return 1
|
|
},
|
|
|
|
"http_server_listen": func(s *luajit.State) int {
|
|
if err := s.CheckExactArgs(2); err != nil {
|
|
return s.PushError("http_server_listen: %v", err)
|
|
}
|
|
|
|
serverID, err := s.SafeToNumber(1)
|
|
if err != nil || serverID != float64(int(serverID)) {
|
|
return s.PushError("http_server_listen: server ID must be an integer")
|
|
}
|
|
|
|
addr, err := s.SafeToString(2)
|
|
if err != nil {
|
|
return s.PushError("http_server_listen: address must be a string")
|
|
}
|
|
|
|
serverRegistry.RLock()
|
|
server, exists := serverRegistry.servers[int(serverID)]
|
|
serverRegistry.RUnlock()
|
|
|
|
if !exists {
|
|
return s.PushError("http_server_listen: server not found")
|
|
}
|
|
|
|
server.mu.Lock()
|
|
if server.running {
|
|
server.mu.Unlock()
|
|
return s.PushError("http_server_listen: server already running")
|
|
}
|
|
|
|
server.addr = addr
|
|
server.running = true
|
|
server.mu.Unlock()
|
|
|
|
go func() {
|
|
err := server.server.ListenAndServe(addr)
|
|
if err != nil {
|
|
server.mu.Lock()
|
|
server.running = false
|
|
server.mu.Unlock()
|
|
}
|
|
}()
|
|
|
|
s.PushBoolean(true)
|
|
return 1
|
|
},
|
|
|
|
"http_server_stop": func(s *luajit.State) int {
|
|
if err := s.CheckMinArgs(1); err != nil {
|
|
return s.PushError("http_server_stop: %v", err)
|
|
}
|
|
|
|
serverID, err := s.SafeToNumber(1)
|
|
if err != nil || serverID != float64(int(serverID)) {
|
|
return s.PushError("http_server_stop: server ID must be an integer")
|
|
}
|
|
|
|
serverRegistry.RLock()
|
|
server, exists := serverRegistry.servers[int(serverID)]
|
|
serverRegistry.RUnlock()
|
|
|
|
if !exists {
|
|
return s.PushError("http_server_stop: server not found")
|
|
}
|
|
|
|
server.mu.Lock()
|
|
if !server.running {
|
|
server.mu.Unlock()
|
|
s.PushBoolean(false)
|
|
return 1
|
|
}
|
|
|
|
server.running = false
|
|
server.mu.Unlock()
|
|
|
|
if err := server.server.Shutdown(); err != nil {
|
|
return s.PushError("http_server_stop: %v", err)
|
|
}
|
|
|
|
s.PushBoolean(true)
|
|
return 1
|
|
},
|
|
|
|
"http_server_get": createRouteHandler("GET"),
|
|
"http_server_post": createRouteHandler("POST"),
|
|
"http_server_put": createRouteHandler("PUT"),
|
|
"http_server_patch": createRouteHandler("PATCH"),
|
|
"http_server_delete": createRouteHandler("DELETE"),
|
|
|
|
"http_server_is_running": func(s *luajit.State) int {
|
|
if err := s.CheckMinArgs(1); err != nil {
|
|
return s.PushError("http_server_is_running: %v", err)
|
|
}
|
|
|
|
serverID, err := s.SafeToNumber(1)
|
|
if err != nil || serverID != float64(int(serverID)) {
|
|
return s.PushError("http_server_is_running: server ID must be an integer")
|
|
}
|
|
|
|
serverRegistry.RLock()
|
|
server, exists := serverRegistry.servers[int(serverID)]
|
|
serverRegistry.RUnlock()
|
|
|
|
if !exists {
|
|
s.PushBoolean(false)
|
|
return 1
|
|
}
|
|
|
|
server.mu.RLock()
|
|
running := server.running
|
|
server.mu.RUnlock()
|
|
|
|
s.PushBoolean(running)
|
|
return 1
|
|
},
|
|
|
|
"http_cleanup_servers": func(s *luajit.State) int {
|
|
serverRegistry.Lock()
|
|
for id, server := range serverRegistry.servers {
|
|
server.mu.Lock()
|
|
if server.running {
|
|
server.server.Shutdown()
|
|
server.running = false
|
|
}
|
|
server.mu.Unlock()
|
|
delete(serverRegistry.servers, id)
|
|
}
|
|
serverRegistry.Unlock()
|
|
|
|
s.PushBoolean(true)
|
|
return 1
|
|
},
|
|
}
|
|
}
|
|
|
|
func createRouteHandler(method string) luajit.GoFunction {
|
|
return func(s *luajit.State) int {
|
|
if err := s.CheckExactArgs(3); err != nil {
|
|
return s.PushError("http_server_%s: %v", method, err)
|
|
}
|
|
|
|
serverID, err := s.SafeToNumber(1)
|
|
if err != nil || serverID != float64(int(serverID)) {
|
|
return s.PushError("http_server_%s: server ID must be an integer", method)
|
|
}
|
|
|
|
path, err := s.SafeToString(2)
|
|
if err != nil {
|
|
return s.PushError("http_server_%s: path must be a string", method)
|
|
}
|
|
|
|
if !s.IsFunction(3) {
|
|
return s.PushError("http_server_%s: handler must be a function", method)
|
|
}
|
|
|
|
serverRegistry.RLock()
|
|
server, exists := serverRegistry.servers[int(serverID)]
|
|
serverRegistry.RUnlock()
|
|
|
|
if !exists {
|
|
return s.PushError("http_server_%s: server not found", method)
|
|
}
|
|
|
|
luaFunc, err := s.StoreLuaFunction(3)
|
|
if err != nil {
|
|
return s.PushError("http_server_%s: failed to store function: %v", method, err)
|
|
}
|
|
|
|
handler := func(ctx *fasthttp.RequestCtx, params []string) {
|
|
server.callLuaHandler(ctx, params, luaFunc)
|
|
}
|
|
|
|
if err := server.router.addRoute(method, path, handler); err != nil {
|
|
return s.PushError("http_server_%s: failed to add route: %v", method, err)
|
|
}
|
|
|
|
s.PushBoolean(true)
|
|
return 1
|
|
}
|
|
}
|
|
|
|
// Router methods
|
|
func (r *Router) methodNode(method string) *node {
|
|
switch method {
|
|
case "GET":
|
|
return r.get
|
|
case "POST":
|
|
return r.post
|
|
case "PUT":
|
|
return r.put
|
|
case "PATCH":
|
|
return r.patch
|
|
case "DELETE":
|
|
return r.delete
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (r *Router) addRoute(method, path string, h Handler) error {
|
|
root := r.methodNode(method)
|
|
if root == nil {
|
|
return fmt.Errorf("unsupported method: %s", method)
|
|
}
|
|
|
|
if path == "/" {
|
|
root.handler = h
|
|
return nil
|
|
}
|
|
|
|
current := root
|
|
pos := 0
|
|
lastWC := false
|
|
count := uint8(0)
|
|
|
|
for {
|
|
seg, newPos, more := readSegment(path, pos)
|
|
if seg == "" {
|
|
break
|
|
}
|
|
|
|
isDyn := len(seg) > 0 && seg[0] == ':'
|
|
isWC := len(seg) > 0 && seg[0] == '*'
|
|
|
|
if isWC {
|
|
if lastWC || more {
|
|
return fmt.Errorf("wildcard must be the last segment in the path")
|
|
}
|
|
lastWC = true
|
|
}
|
|
|
|
if isDyn || isWC {
|
|
count++
|
|
}
|
|
|
|
var child *node
|
|
for _, c := range current.children {
|
|
if c.segment == seg {
|
|
child = c
|
|
break
|
|
}
|
|
}
|
|
|
|
if child == nil {
|
|
child = &node{segment: seg, isDynamic: isDyn, isWildcard: isWC}
|
|
current.children = append(current.children, child)
|
|
}
|
|
|
|
if child.maxParams < count {
|
|
child.maxParams = count
|
|
}
|
|
|
|
current = child
|
|
pos = newPos
|
|
}
|
|
|
|
current.handler = h
|
|
return nil
|
|
}
|
|
|
|
func (r *Router) lookup(method, path string) (Handler, []string, bool) {
|
|
root := r.methodNode(method)
|
|
if root == nil {
|
|
return nil, nil, false
|
|
}
|
|
|
|
if path == "/" {
|
|
return root.handler, nil, root.handler != nil
|
|
}
|
|
|
|
buffer := r.paramsBuffer
|
|
if cap(buffer) < int(root.maxParams) {
|
|
buffer = make([]string, root.maxParams)
|
|
r.paramsBuffer = buffer
|
|
}
|
|
buffer = buffer[:0]
|
|
|
|
h, paramCount, found := match(root, path, 0, &buffer)
|
|
if !found {
|
|
return nil, nil, false
|
|
}
|
|
|
|
return h, buffer[:paramCount], true
|
|
}
|
|
|
|
// HTTPServer methods
|
|
func (hs *HTTPServer) handleRequest(ctx *fasthttp.RequestCtx) {
|
|
method := string(ctx.Method())
|
|
path := string(ctx.Path())
|
|
|
|
handler, params, found := hs.router.lookup(method, path)
|
|
if !found {
|
|
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
|
ctx.WriteString("Not Found")
|
|
return
|
|
}
|
|
|
|
handler(ctx, params)
|
|
}
|
|
|
|
func (hs *HTTPServer) callLuaHandler(ctx *fasthttp.RequestCtx, params []string, handler *luajit.LuaFunction) {
|
|
hs.luaMu.Lock()
|
|
defer hs.luaMu.Unlock()
|
|
|
|
request := map[string]interface{}{
|
|
"method": string(ctx.Method()),
|
|
"path": string(ctx.Path()),
|
|
"query": string(ctx.QueryArgs().QueryString()),
|
|
"headers": make(map[string]string),
|
|
"body": string(ctx.PostBody()),
|
|
"remote": ctx.RemoteAddr().String(),
|
|
"params": params,
|
|
}
|
|
|
|
headers := request["headers"].(map[string]string)
|
|
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
|
headers[string(key)] = string(value)
|
|
})
|
|
|
|
response := map[string]interface{}{
|
|
"status": 200,
|
|
"headers": make(map[string]string),
|
|
"body": "",
|
|
}
|
|
|
|
results, err := handler.Call(request, response)
|
|
if err != nil {
|
|
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
|
ctx.WriteString(fmt.Sprintf("Handler error: %v", err))
|
|
return
|
|
}
|
|
|
|
if len(results) > 0 {
|
|
if respMap, ok := results[0].(map[string]interface{}); ok {
|
|
response = respMap
|
|
}
|
|
}
|
|
|
|
if status, ok := response["status"].(int); ok {
|
|
ctx.SetStatusCode(status)
|
|
} else if status, ok := response["status"].(float64); ok {
|
|
ctx.SetStatusCode(int(status))
|
|
}
|
|
|
|
if headers, ok := response["headers"].(map[string]interface{}); ok {
|
|
for k, v := range headers {
|
|
if str, ok := v.(string); ok {
|
|
ctx.Response.Header.Set(k, str)
|
|
}
|
|
}
|
|
}
|
|
|
|
if body, ok := response["body"].(string); ok {
|
|
ctx.WriteString(body)
|
|
}
|
|
}
|
|
|
|
// Utility functions

// readSegment returns the path segment that begins at byte offset start,
// skipping one leading '/'. end is the offset just past the segment (the
// next '/' or len(path)), and hasMore reports whether a '/' follows it.
// The segment is "" at the end of the path (including after a trailing
// slash) and also for an empty segment such as the middle of "a//b", in
// which case hasMore is true.
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
	begin := start
	if begin < len(path) && path[begin] == '/' {
		begin++
	}
	if begin >= len(path) {
		return "", begin, false
	}

	stop := begin
	for stop < len(path) && path[stop] != '/' {
		stop++
	}
	return path[begin:stop], stop, stop < len(path)
}
|
|
|
|
func match(current *node, path string, start int, params *[]string) (Handler, int, bool) {
|
|
paramCount := 0
|
|
|
|
for _, c := range current.children {
|
|
if c.isWildcard {
|
|
rem := path[start:]
|
|
if len(rem) > 0 && rem[0] == '/' {
|
|
rem = rem[1:]
|
|
}
|
|
*params = append(*params, rem)
|
|
return c.handler, 1, c.handler != nil
|
|
}
|
|
}
|
|
|
|
seg, pos, more := readSegment(path, start)
|
|
if seg == "" {
|
|
return current.handler, 0, current.handler != nil
|
|
}
|
|
|
|
for _, c := range current.children {
|
|
if c.segment == seg || c.isDynamic {
|
|
if c.isDynamic {
|
|
*params = append(*params, seg)
|
|
paramCount++
|
|
}
|
|
if !more {
|
|
return c.handler, paramCount, c.handler != nil
|
|
}
|
|
h, nestedCount, ok := match(c, path, pos, params)
|
|
if ok {
|
|
return h, paramCount + nestedCount, true
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, 0, false
|
|
}
|