347 lines
6.0 KiB
Go
347 lines
6.0 KiB
Go
package http
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
// StateCreator builds a fresh Lua state for a worker. The pool invokes it once
// for every pooled worker while NewWorkerPool runs, and again each time Get has
// to create a temporary worker because no pooled worker is idle.
type StateCreator func() (state *luajit.State, err error)
|
|
|
|
// Request is the Go-side description of an incoming HTTP request that
// Worker.HandleRequest turns into a Lua table for the router. Values are
// recycled through GetRequest and PutRequest.
type Request struct {
	Method, Path string
	Query        map[string]string
	Headers      map[string]string
	Body         string
}
|
|
|
|
// Response carries what the Lua handler produced back to Go. Values obtained
// from GetResponse start with StatusCode 200, an empty Body and no headers.
type Response struct {
	StatusCode int
	Headers    map[string]string
	Body       string
}
|
|
|
|
type Worker struct {
|
|
state *luajit.State
|
|
id int
|
|
}
|
|
|
|
type WorkerPool struct {
|
|
workers chan *Worker
|
|
masterState *luajit.State
|
|
stateCreator StateCreator
|
|
workerCount int
|
|
closed bool
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
var (
|
|
requestPool = sync.Pool{
|
|
New: func() any {
|
|
return &Request{
|
|
Query: make(map[string]string),
|
|
Headers: make(map[string]string),
|
|
}
|
|
},
|
|
}
|
|
|
|
responsePool = sync.Pool{
|
|
New: func() any {
|
|
return &Response{
|
|
Headers: make(map[string]string),
|
|
}
|
|
},
|
|
}
|
|
)
|
|
|
|
func GetRequest() *Request {
|
|
req := requestPool.Get().(*Request)
|
|
for k := range req.Query {
|
|
delete(req.Query, k)
|
|
}
|
|
for k := range req.Headers {
|
|
delete(req.Headers, k)
|
|
}
|
|
req.Method = ""
|
|
req.Path = ""
|
|
req.Body = ""
|
|
return req
|
|
}
|
|
|
|
func PutRequest(req *Request) {
|
|
requestPool.Put(req)
|
|
}
|
|
|
|
func GetResponse() *Response {
|
|
resp := responsePool.Get().(*Response)
|
|
for k := range resp.Headers {
|
|
delete(resp.Headers, k)
|
|
}
|
|
resp.StatusCode = 200
|
|
resp.Body = ""
|
|
return resp
|
|
}
|
|
|
|
func PutResponse(resp *Response) {
|
|
responsePool.Put(resp)
|
|
}
|
|
|
|
func NewWorkerPool(size int, masterState *luajit.State, stateCreator StateCreator) (*WorkerPool, error) {
|
|
pool := &WorkerPool{
|
|
workers: make(chan *Worker, size),
|
|
masterState: masterState,
|
|
stateCreator: stateCreator,
|
|
workerCount: size,
|
|
}
|
|
|
|
for i := range size {
|
|
worker, err := pool.createWorker(i)
|
|
if err != nil {
|
|
pool.Close()
|
|
return nil, fmt.Errorf("failed to create worker %d: %w", i, err)
|
|
}
|
|
pool.workers <- worker
|
|
}
|
|
|
|
return pool, nil
|
|
}
|
|
|
|
func (p *WorkerPool) createWorker(id int) (*Worker, error) {
|
|
workerState, err := p.stateCreator()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create worker state: %w", err)
|
|
}
|
|
|
|
return &Worker{
|
|
state: workerState,
|
|
id: id,
|
|
}, nil
|
|
}
|
|
|
|
func (p *WorkerPool) Get() *Worker {
|
|
p.mu.RLock()
|
|
if p.closed {
|
|
p.mu.RUnlock()
|
|
return nil
|
|
}
|
|
p.mu.RUnlock()
|
|
|
|
select {
|
|
case worker := <-p.workers:
|
|
return worker
|
|
default:
|
|
worker, err := p.createWorker(-1)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return worker
|
|
}
|
|
}
|
|
|
|
func (p *WorkerPool) Put(worker *Worker) {
|
|
if worker == nil {
|
|
return
|
|
}
|
|
|
|
p.mu.RLock()
|
|
defer p.mu.RUnlock()
|
|
|
|
if p.closed {
|
|
worker.Close()
|
|
return
|
|
}
|
|
|
|
if worker.id == -1 {
|
|
worker.Close()
|
|
return
|
|
}
|
|
|
|
select {
|
|
case p.workers <- worker:
|
|
default:
|
|
worker.Close()
|
|
}
|
|
}
|
|
|
|
func (p *WorkerPool) SyncRoutes(routesData any) {
|
|
p.mu.RLock()
|
|
defer p.mu.RUnlock()
|
|
|
|
if p.closed {
|
|
return
|
|
}
|
|
|
|
// Sync routes to all workers
|
|
workers := make([]*Worker, 0, p.workerCount)
|
|
|
|
// Collect all workers
|
|
for {
|
|
select {
|
|
case worker := <-p.workers:
|
|
workers = append(workers, worker)
|
|
default:
|
|
goto syncWorkers
|
|
}
|
|
}
|
|
|
|
syncWorkers:
|
|
// Sync and return workers
|
|
for _, worker := range workers {
|
|
if worker.state != nil {
|
|
worker.state.PushValue(routesData)
|
|
worker.state.SetGlobal("_http_routes_data")
|
|
|
|
worker.state.GetGlobal("_http_sync_worker_routes")
|
|
if worker.state.IsFunction(-1) {
|
|
worker.state.Call(0, 0)
|
|
} else {
|
|
worker.state.Pop(1)
|
|
}
|
|
}
|
|
|
|
select {
|
|
case p.workers <- worker:
|
|
default:
|
|
worker.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *WorkerPool) Close() {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
if p.closed {
|
|
return
|
|
}
|
|
p.closed = true
|
|
|
|
// Collect all workers first
|
|
workers := make([]*Worker, 0, len(p.workers))
|
|
close(p.workers)
|
|
for worker := range p.workers {
|
|
workers = append(workers, worker)
|
|
}
|
|
|
|
// Close all workers in parallel
|
|
var wg sync.WaitGroup
|
|
for _, worker := range workers {
|
|
wg.Add(1)
|
|
go func(w *Worker) {
|
|
defer wg.Done()
|
|
w.Close()
|
|
}(worker)
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func (w *Worker) Close() {
|
|
if w.state != nil {
|
|
w.state.Close()
|
|
w.state = nil
|
|
}
|
|
}
|
|
|
|
func (w *Worker) HandleRequest(req *Request, resp *Response) error {
|
|
if w.state == nil {
|
|
return fmt.Errorf("worker state is nil")
|
|
}
|
|
|
|
// Create request table
|
|
w.state.NewTable()
|
|
w.state.PushString("method")
|
|
w.state.PushString(req.Method)
|
|
w.state.SetTable(-3)
|
|
|
|
w.state.PushString("path")
|
|
w.state.PushString(req.Path)
|
|
w.state.SetTable(-3)
|
|
|
|
w.state.PushString("body")
|
|
w.state.PushString(req.Body)
|
|
w.state.SetTable(-3)
|
|
|
|
// Query params
|
|
w.state.PushString("query")
|
|
w.state.NewTable()
|
|
for k, v := range req.Query {
|
|
w.state.PushString(k)
|
|
w.state.PushString(v)
|
|
w.state.SetTable(-3)
|
|
}
|
|
w.state.SetTable(-3)
|
|
|
|
// Headers
|
|
w.state.PushString("headers")
|
|
w.state.NewTable()
|
|
for k, v := range req.Headers {
|
|
w.state.PushString(k)
|
|
w.state.PushString(v)
|
|
w.state.SetTable(-3)
|
|
}
|
|
w.state.SetTable(-3)
|
|
|
|
// Create response table
|
|
w.state.NewTable()
|
|
w.state.PushString("status")
|
|
w.state.PushNumber(200)
|
|
w.state.SetTable(-3)
|
|
|
|
w.state.PushString("body")
|
|
w.state.PushString("")
|
|
w.state.SetTable(-3)
|
|
|
|
w.state.PushString("headers")
|
|
w.state.NewTable()
|
|
w.state.SetTable(-3)
|
|
|
|
// Call _http_handle_request(req, res) - pure Lua routing
|
|
w.state.GetGlobal("_http_handle_request")
|
|
if !w.state.IsFunction(-1) {
|
|
w.state.Pop(3)
|
|
resp.StatusCode = 500
|
|
resp.Body = "HTTP handler not initialized"
|
|
return nil
|
|
}
|
|
|
|
w.state.PushCopy(-3) // request
|
|
w.state.PushCopy(-3) // response
|
|
|
|
if err := w.state.Call(2, 0); err != nil {
|
|
w.state.Pop(2)
|
|
resp.StatusCode = 500
|
|
resp.Body = fmt.Sprintf("Handler error: %v", err)
|
|
return nil
|
|
}
|
|
|
|
// Extract response
|
|
w.state.GetField(-1, "status")
|
|
if w.state.IsNumber(-1) {
|
|
resp.StatusCode = int(w.state.ToNumber(-1))
|
|
}
|
|
w.state.Pop(1)
|
|
|
|
w.state.GetField(-1, "body")
|
|
if w.state.IsString(-1) {
|
|
resp.Body = w.state.ToString(-1)
|
|
}
|
|
w.state.Pop(1)
|
|
|
|
w.state.GetField(-1, "headers")
|
|
if w.state.IsTable(-1) {
|
|
w.state.PushNil()
|
|
for w.state.Next(-2) {
|
|
if w.state.IsString(-2) && w.state.IsString(-1) {
|
|
resp.Headers[w.state.ToString(-2)] = w.state.ToString(-1)
|
|
}
|
|
w.state.Pop(1)
|
|
}
|
|
}
|
|
w.state.Pop(1)
|
|
|
|
w.state.Pop(2) // Clean up request and response tables
|
|
return nil
|
|
}
|