347 lines
6.0 KiB
Go

package http
import (
"fmt"
"sync"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// StateCreator builds a fresh Lua state for a worker. The pool calls it once
// for every pooled worker in NewWorkerPool, and again whenever Get has no idle
// worker and creates a temporary one.
type StateCreator func() (*luajit.State, error)
// Request is the Go-side description of an incoming HTTP request. Worker
// HandleRequest copies it into a Lua table with the keys method, path, body,
// query and headers. Values are recycled with GetRequest / PutRequest.
type Request struct {
	Method, Path   string
	Query, Headers map[string]string
	Body           string
}
// Response receives the result of a Lua handler. Worker HandleRequest fills
// StatusCode, Body and Headers from the Lua response table. GetResponse hands
// out pooled values reset to StatusCode 200 with an empty body and no headers.
type Response struct {
	StatusCode int
	Headers    map[string]string
	Body       string
}
// Worker owns a single Lua state that requests are executed on.
type Worker struct {
	// state is nil once the worker has been closed.
	state *luajit.State
	// id is the worker's slot index for pooled workers created by
	// NewWorkerPool. It is -1 for temporary workers created by Get;
	// Put closes those instead of pooling them.
	id int
}
// WorkerPool hands out Workers, each with its own Lua state, from a buffered
// channel of idle workers.
type WorkerPool struct {
	// workers holds idle pooled workers; its buffer size equals workerCount.
	workers chan *Worker
	// masterState is stored as passed to NewWorkerPool; no method shown here
	// reads or closes it. NOTE(review): confirm who owns and closes it.
	masterState *luajit.State
	// stateCreator builds the Lua state for every new worker.
	stateCreator StateCreator
	// workerCount is the number of pooled (non-temporary) workers.
	workerCount int
	// closed is set once by Close and read under mu.
	closed bool
	// mu guards closed. Close takes the write lock before closing workers.
	// Put and SyncRoutes send only while holding the read lock, after
	// checking closed, so they never send on a closed channel.
	mu sync.RWMutex
}
// Object pools for Request and Response values. The maps are allocated once
// when a value is first created; GetRequest / GetResponse empty them on reuse.
var (
	// requestPool recycles *Request values with pre-allocated Query and
	// Headers maps.
	requestPool = sync.Pool{
		New: func() any {
			req := new(Request)
			req.Query = make(map[string]string)
			req.Headers = make(map[string]string)
			return req
		},
	}

	// responsePool recycles *Response values with a pre-allocated Headers map.
	responsePool = sync.Pool{
		New: func() any {
			resp := new(Response)
			resp.Headers = make(map[string]string)
			return resp
		},
	}
)
// GetRequest takes a *Request from the shared pool and resets it.
//
// After reset, Method, Path and Body are empty. Query and Headers are empty
// but non-nil maps, and they keep their allocated storage across reuse.
// Hand the value back with PutRequest when finished.
func GetRequest() *Request {
	req := requestPool.Get().(*Request)

	// Callers may have replaced a map with nil before PutRequest. Re-create it
	// in that case so writes like req.Query[k] = v cannot panic.
	if req.Query == nil {
		req.Query = make(map[string]string)
	} else {
		clear(req.Query)
	}
	if req.Headers == nil {
		req.Headers = make(map[string]string)
	} else {
		clear(req.Headers)
	}

	req.Method = ""
	req.Path = ""
	req.Body = ""
	return req
}
// PutRequest returns req to the shared pool. The caller must not use req
// afterwards.
//
// A nil req is ignored. Storing a typed nil pointer would make the pool hold a
// non-nil interface wrapping nil, and a later GetRequest would dereference it
// and panic.
func PutRequest(req *Request) {
	if req == nil {
		return
	}
	requestPool.Put(req)
}
// GetResponse takes a *Response from the shared pool and resets it to
// StatusCode 200, an empty Body and an empty, non-nil Headers map. Hand it
// back with PutResponse when finished.
func GetResponse() *Response {
	resp := responsePool.Get().(*Response)

	// Reuse the map's storage. Re-create it only if a caller set it to nil,
	// because response headers are written into it directly by key.
	if resp.Headers == nil {
		resp.Headers = make(map[string]string)
	} else {
		clear(resp.Headers)
	}

	resp.StatusCode = 200
	resp.Body = ""
	return resp
}
// PutResponse returns resp to the shared pool. The caller must not use resp
// afterwards.
//
// A nil resp is ignored. A typed nil pointer stored in the pool would come back
// from GetResponse and be dereferenced, causing a panic.
func PutResponse(resp *Response) {
	if resp == nil {
		return
	}
	responsePool.Put(resp)
}
// NewWorkerPool creates a pool of size workers. Each worker gets its own Lua
// state built by stateCreator. masterState is stored on the pool as given.
//
// If any worker fails to initialise, every worker created so far is closed and
// the error is returned. Two inputs are rejected with an error rather than
// panicking later:
//   - a negative size, which would make the channel allocation panic;
//   - a nil stateCreator, which would panic on the first worker creation,
//     including temporary workers made by Get.
func NewWorkerPool(size int, masterState *luajit.State, stateCreator StateCreator) (*WorkerPool, error) {
	if size < 0 {
		return nil, fmt.Errorf("invalid worker pool size %d", size)
	}
	if stateCreator == nil {
		return nil, fmt.Errorf("worker pool requires a state creator")
	}

	pool := &WorkerPool{
		workers:      make(chan *Worker, size),
		masterState:  masterState,
		stateCreator: stateCreator,
		workerCount:  size,
	}

	for i := range size {
		worker, err := pool.createWorker(i)
		if err != nil {
			pool.Close()
			return nil, fmt.Errorf("failed to create worker %d: %w", i, err)
		}
		// Never blocks: the channel is buffered to exactly size.
		pool.workers <- worker
	}
	return pool, nil
}
// createWorker wraps a fresh Lua state from the pool's StateCreator in a
// Worker. Pooled workers receive their slot index as id; Get passes -1 for
// temporary workers.
func (p *WorkerPool) createWorker(id int) (*Worker, error) {
	state, err := p.stateCreator()
	if err != nil {
		return nil, fmt.Errorf("failed to create worker state: %w", err)
	}
	w := &Worker{id: id, state: state}
	return w, nil
}
// Get returns an idle pooled worker when one is available. Otherwise it
// creates a temporary worker (id -1) instead of blocking.
//
// It returns nil in three cases: the pool is closed, the idle channel was
// closed concurrently, or the temporary worker's state could not be created.
func (p *WorkerPool) Get() *Worker {
	p.mu.RLock()
	isClosed := p.closed
	p.mu.RUnlock()
	if isClosed {
		return nil
	}

	select {
	case idle := <-p.workers:
		return idle
	default:
	}

	// No idle worker: fall back to a one-off worker that Put will close.
	extra, err := p.createWorker(-1)
	if err != nil {
		return nil
	}
	return extra
}
// Put hands a worker back after use. A nil worker is ignored.
//
// The worker's Lua state is closed instead of being pooled when:
//   - the pool has been closed;
//   - the worker is temporary (id -1);
//   - the idle channel has no room.
func (p *WorkerPool) Put(worker *Worker) {
	if worker == nil {
		return
	}

	// Holding the read lock prevents Close from closing p.workers mid-send.
	p.mu.RLock()
	defer p.mu.RUnlock()

	if !p.closed && worker.id != -1 {
		select {
		case p.workers <- worker:
			return
		default:
		}
	}
	worker.Close()
}
// SyncRoutes pushes routesData into every worker that is currently idle in
// the pool. It does nothing if the pool is closed.
//
// For each idle worker with a live state, it:
//  1. stores routesData in the Lua global _http_routes_data;
//  2. calls the Lua global _http_sync_worker_routes with no arguments, if it
//     is a function.
//
// Each worker is then returned to the idle channel.
//
// NOTE(review): only idle workers are updated. Workers checked out via Get
// while this runs, and temporary workers Get creates while the pool is
// drained, keep whatever routes they had. Confirm callers account for this,
// or add route versioning.
func (p *WorkerPool) SyncRoutes(routesData any) {
	p.mu.RLock()
	defer p.mu.RUnlock()
	if p.closed {
		return
	}

	// Drain every idle worker first so each is synced exactly once.
	idle := make([]*Worker, 0, len(p.workers))
drain:
	for {
		select {
		case worker := <-p.workers:
			idle = append(idle, worker)
		default:
			break drain
		}
	}

	for _, worker := range idle {
		if s := worker.state; s != nil {
			s.PushValue(routesData)
			s.SetGlobal("_http_routes_data")
			s.GetGlobal("_http_sync_worker_routes")
			if s.IsFunction(-1) {
				// Best-effort: SyncRoutes has no error result, so a failing
				// Lua sync hook leaves this worker on its previous routes.
				_ = s.Call(0, 0)
			} else {
				s.Pop(1)
			}
		}

		select {
		case p.workers <- worker:
		default:
			worker.Close()
		}
	}
}
// Close marks the pool closed and shuts down the idle workers' Lua states.
// The states are closed concurrently, and Close waits for all of them.
// Workers checked out at the time are closed later, when they are passed to
// Put. Calling Close again is a no-op.
func (p *WorkerPool) Close() {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.closed {
		return
	}
	p.closed = true

	// Safe under the write lock: senders check closed while holding the
	// read lock before sending.
	close(p.workers)

	var wg sync.WaitGroup
	for idle := range p.workers {
		wg.Add(1)
		go func() {
			defer wg.Done()
			idle.Close()
		}()
	}
	wg.Wait()
}
// Close releases the worker's Lua state. Repeated calls are harmless: the
// state is set to nil after the first close, and later calls return early.
func (w *Worker) Close() {
	if w.state == nil {
		return
	}
	w.state.Close()
	w.state = nil
}
// HandleRequest runs one request through the Lua handler on this worker's
// state.
//
// It builds two Lua tables and calls the Lua global
// _http_handle_request(req, res):
//   - req: {method, path, body, query = {...}, headers = {...}}
//   - res: {status = 200, body = "", headers = {}}
//
// Afterwards it copies back into resp:
//   - res.status, if it is a number;
//   - res.body, if it is a string;
//   - every string-keyed, string-valued entry of res.headers.
//
// Handler problems are reported through resp, not the error result:
//   - a missing handler yields 500 "HTTP handler not initialized";
//   - a failing call yields 500 "Handler error: ...".
//
// An error is returned only when the worker's state has already been closed.
func (w *Worker) HandleRequest(req *Request, resp *Response) error {
	if w.state == nil {
		return fmt.Errorf("worker state is nil")
	}
	s := w.state

	// Request table.
	s.NewTable()
	w.setLuaString("method", req.Method)
	w.setLuaString("path", req.Path)
	w.setLuaString("body", req.Body)
	w.setLuaStringTable("query", req.Query)
	w.setLuaStringTable("headers", req.Headers)

	// Response table with defaults the handler may overwrite.
	s.NewTable()
	s.PushString("status")
	s.PushNumber(200)
	s.SetTable(-3)
	w.setLuaString("body", "")
	w.setLuaStringTable("headers", nil)

	// Stack: req, res.
	s.GetGlobal("_http_handle_request")
	if !s.IsFunction(-1) {
		s.Pop(3)
		resp.StatusCode = 500
		resp.Body = "HTTP handler not initialized"
		return nil
	}

	// Pass copies so req and res stay on the stack for reading afterwards.
	s.PushCopy(-3) // request
	s.PushCopy(-3) // response
	if err := s.Call(2, 0); err != nil {
		// NOTE(review): assumes Call leaves only req and res on the stack on
		// failure; confirm against the LuaJIT-to-Go binding.
		s.Pop(2)
		resp.StatusCode = 500
		resp.Body = fmt.Sprintf("Handler error: %v", err)
		return nil
	}

	w.readLuaResponse(resp)
	s.Pop(2) // req and res
	return nil
}

// setLuaString sets t[key] = val on the table at the top of the stack.
func (w *Worker) setLuaString(key, val string) {
	w.state.PushString(key)
	w.state.PushString(val)
	w.state.SetTable(-3)
}

// setLuaStringTable sets t[key] to a new table holding m's entries, on the
// table at the top of the stack. A nil m produces an empty table.
func (w *Worker) setLuaStringTable(key string, m map[string]string) {
	s := w.state
	s.PushString(key)
	s.NewTable()
	for k, v := range m {
		s.PushString(k)
		s.PushString(v)
		s.SetTable(-3)
	}
	s.SetTable(-3)
}

// readLuaResponse copies status, body and headers from the Lua response
// table at the top of the stack into resp. It leaves the stack as found.
func (w *Worker) readLuaResponse(resp *Response) {
	s := w.state

	s.GetField(-1, "status")
	if s.IsNumber(-1) {
		resp.StatusCode = int(s.ToNumber(-1))
	}
	s.Pop(1)

	s.GetField(-1, "body")
	if s.IsString(-1) {
		resp.Body = s.ToString(-1)
	}
	s.Pop(1)

	s.GetField(-1, "headers")
	if s.IsTable(-1) {
		// Responses built by hand may lack the map; avoid a nil-map write.
		if resp.Headers == nil {
			resp.Headers = make(map[string]string)
		}
		s.PushNil()
		for s.Next(-2) {
			// Stack: ..., headers, key, value.
			if s.IsString(-2) && s.IsString(-1) {
				// Convert a copy of the key. Converting the key in place
				// would corrupt the next Next call. That can happen for a
				// numeric key if IsString follows lua_isstring semantics.
				s.PushCopy(-2)
				name := s.ToString(-1)
				s.Pop(1)
				resp.Headers[name] = s.ToString(-1)
			}
			s.Pop(1) // drop value, keep key for Next
		}
	}
	s.Pop(1) // headers field
}