Moonshark/http/http.go
2025-07-14 17:36:59 -05:00

598 lines
14 KiB
Go

package http
import (
	"crypto/rand"
	_ "embed"
	"encoding/base64"
	"fmt"
	"mime/multipart"
	"net"
	"strings"
	"sync"
	"time"

	"Moonshark/http/router"
	"Moonshark/http/sessions"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
	"github.com/goccy/go-json"
	"github.com/valyala/fasthttp"
)
// httpLuaCode holds the contents of http.lua, embedded at build time.
// RegisterHTTPFunctions executes it once with DoString, after the
// __http_listen and __http_route Go functions have been registered.
//go:embed http.lua
var httpLuaCode string
// HandlerFunc describes one registered route handler. isFunction selects
// which of the two representations is used:
//   - isFunction == true: funcRef is the key under which the Lua function is
//     stored in the __moonshark_functions global table.
//   - isFunction == false: bytecode holds the compiled handler source.
//
// NOTE(review): despite the name this is a struct, not a function type.
type HandlerFunc struct {
// bytecode is the compiled Lua chunk for string handlers (nil for function handlers).
bytecode []byte
// funcRef is the __moonshark_functions key for function handlers.
funcRef int
// name is "<method> <path>" as given to the route call; it is passed as the
// chunk name when the bytecode is run.
name string
// isFunction is true for stored Lua functions, false for bytecode handlers.
isFunction bool
}
// Server ties together the fasthttp listener, the route table, the session
// manager and the single Lua state on which every request handler runs.
type Server struct {
// server is the fasthttp server created by httpListen; nil until listen is called.
server *fasthttp.Server
// router maps (METHOD, path) to handler IDs that index handlers.
router *router.Router
// sessions looks up the session for each request and writes its cookie.
sessions *sessions.SessionManager
// state is the shared Lua state used by all handlers.
state *luajit.State
// stateMu serializes request handlers' use of state.
// NOTE(review): route registration runs on state without taking this lock.
stateMu sync.Mutex
// handlers maps handler IDs to handler definitions; guarded by handlersMu.
handlers map[int]*HandlerFunc
handlersMu sync.RWMutex
// funcCounter is one sequence shared by handler IDs (httpRoute) and
// stored-function keys (storeFunction). It is not protected by a mutex.
funcCounter int
}
// RequestContext bundles the per-request data passed to the environment
// setup functions: the fasthttp context, the matched route parameters and
// the client's session.
// NOTE(review): parsedForm and formOnce are not referenced anywhere in this
// file; forms are parsed eagerly by parseForm. This looks like an unfinished
// lazy-parsing path.
type RequestContext struct {
ctx *fasthttp.RequestCtx
params *router.Params
session *sessions.Session
parsedForm map[string]any
formOnce sync.Once
}
// globalServer is the Server created by the most recent RegisterHTTPFunctions
// call; the __http_listen/__http_route Lua globals are bound to its methods.
var globalServer *Server
// NewServer returns a Server that runs handlers on state, with an empty
// router, an empty handler table and a fresh session manager.
// The 10000 passed to NewSessionManager is presumably a capacity limit;
// confirm in the sessions package.
func NewServer(state *luajit.State) *Server {
	srv := &Server{state: state}
	srv.router = router.New()
	srv.sessions = sessions.NewSessionManager(10000)
	srv.handlers = make(map[int]*HandlerFunc)
	return srv
}
// RegisterHTTPFunctions creates a new Server on L, stores it in globalServer,
// binds its listen/route methods as the Lua globals __http_listen and
// __http_route, and finally runs the embedded http.lua source.
//
// It returns the first registration error, or the result of running http.lua.
func RegisterHTTPFunctions(L *luajit.State) error {
	globalServer = NewServer(L)
	srv := globalServer

	bindings := []struct {
		name string
		fn   luajit.GoFunction
	}{
		{"__http_listen", srv.httpListen},
		{"__http_route", srv.httpRoute},
	}
	for _, b := range bindings {
		if err := L.RegisterGoFunction(b.name, b.fn); err != nil {
			return err
		}
	}

	return L.DoString(httpLuaCode)
}
// httpListen implements the Lua global __http_listen(port).
//
// It binds a TCP listener on all IPv4 interfaces at port and serves requests
// with fastRequestHandler on a background goroutine. The bind happens here,
// before returning, so that an invalid port or an address already in use is
// reported to the Lua caller as an error. Otherwise the failure would only be
// printed from the serving goroutine after success had been returned.
//
// Lua arguments: port (number; an integer in 0-65535, 0 lets the OS choose).
// Lua returns: true on success, or the PushError result on failure.
func (s *Server) httpListen(state *luajit.State) int {
	port, err := state.SafeToNumber(1)
	if err != nil {
		return state.PushError("listen: port must be number")
	}
	p := int(port)
	if float64(p) != port || p < 0 || p > 65535 {
		return state.PushError("listen: port must be an integer between 0 and 65535")
	}

	// fasthttp's ListenAndServe listens on "tcp4"; keep the same network but
	// bind up front so errors surface before success is reported.
	ln, err := net.Listen("tcp4", fmt.Sprintf(":%d", p))
	if err != nil {
		return state.PushError("listen: %s", err.Error())
	}

	srv := &fasthttp.Server{
		Handler:               s.fastRequestHandler,
		Name:                  "Moonshark/2.0",
		Concurrency:           256 * 1024,
		ReadBufferSize:        4096,
		WriteBufferSize:       4096,
		ReduceMemoryUsage:     true,
		NoDefaultServerHeader: true,
	}
	s.server = srv

	// Report the port actually bound; it differs from p only when p == 0.
	if tcpAddr, ok := ln.Addr().(*net.TCPAddr); ok {
		p = tcpAddr.Port
	}

	// Serve from the local srv/ln rather than s.server, so a later listen
	// call that replaces s.server cannot change what this goroutine serves.
	go func() {
		if err := srv.Serve(ln); err != nil {
			fmt.Printf("Server error: %v\n", err)
		}
	}()

	fmt.Printf("Server listening on port %d\n", p)
	state.PushBoolean(true)
	return 1
}
// httpRoute implements the Lua global __http_route(method, path, handler).
//
// handler is either a Lua function, which is stored in __moonshark_functions
// via storeFunction, or a string of Lua source, which is compiled to bytecode
// now. The handler is recorded under a new ID before the route is added, so
// a concurrent request matching the new route always finds its handler. If
// the router rejects the route, the recorded handler is removed again,
// because no request can ever reach that ID.
//
// Lua returns: true on success, or the PushError result on failure.
// NOTE(review): on router failure a stored Lua function stays in
// __moonshark_functions; it is unreachable but not freed.
func (s *Server) httpRoute(state *luajit.State) int {
	method, err := state.SafeToString(1)
	if err != nil {
		return state.PushError("route: method must be string")
	}
	path, err := state.SafeToString(2)
	if err != nil {
		return state.PushError("route: path must be string")
	}
	name := fmt.Sprintf("%s %s", method, path)

	// funcCounter is shared with storeFunction, so handler IDs and function
	// keys are drawn from one sequence.
	s.funcCounter++
	handlerID := s.funcCounter

	var handler *HandlerFunc
	if state.IsFunction(3) {
		// Function handler: store a copy of the function and keep its key.
		state.PushCopy(3)
		handler = &HandlerFunc{
			funcRef:    s.storeFunction(state),
			name:       name,
			isFunction: true,
		}
	} else {
		// String handler: compile to bytecode once, at registration time.
		handlerCode, err := state.SafeToString(3)
		if err != nil {
			return state.PushError("route: handler must be function or string")
		}
		bytecode, err := state.CompileBytecode(handlerCode, fmt.Sprintf("handler_%s_%s", method, path))
		if err != nil {
			return state.PushError("route: failed to compile handler: %s", err.Error())
		}
		handler = &HandlerFunc{
			bytecode:   bytecode,
			name:       name,
			isFunction: false,
		}
	}

	s.handlersMu.Lock()
	s.handlers[handlerID] = handler
	s.handlersMu.Unlock()

	if err := s.router.AddRoute(strings.ToUpper(method), path, handlerID); err != nil {
		s.handlersMu.Lock()
		delete(s.handlers, handlerID)
		s.handlersMu.Unlock()
		return state.PushError("route: failed to add route: %s", err.Error())
	}

	state.PushBoolean(true)
	return 1
}
// storeFunction pops the value on top of the Lua stack and stores it in the
// global table __moonshark_functions, creating that table on first use. The
// value is stored under a fresh numeric key, which is returned.
//
// The key comes from funcCounter, the same counter httpRoute uses for
// handler IDs, so keys and handler IDs come from one sequence.
// Net stack effect: [..., v] -> [...].
func (s *Server) storeFunction(state *luajit.State) int {
state.GetGlobal("__moonshark_functions")
if state.IsNil(-1) {
// First use: replace the nil with a new table, publish it as the global,
// and keep one copy on the stack.
state.Pop(1)
state.NewTable()
state.PushCopy(-1)
state.SetGlobal("__moonshark_functions")
}
// Stack: [..., v, tbl]
s.funcCounter++
state.PushNumber(float64(s.funcCounter))
// Copy v (at -3, below tbl and the key), then tbl[key] = v, which pops key and copy.
state.PushCopy(-3)
state.SetTable(-3)
// Drop tbl and the original v.
state.Pop(2)
return s.funcCounter
}
// getFunction looks up ref in the __moonshark_functions global table.
// If the entry is a function, it is left on top of the stack and true is
// returned. Otherwise the stack is restored to its previous height and false
// is returned.
func (s *Server) getFunction(state *luajit.State, ref int) bool {
	state.GetGlobal("__moonshark_functions")
	if state.IsNil(-1) {
		state.Pop(1) // no registry yet: drop the nil
		return false
	}

	// [..., registry] -> [..., registry, registry[ref]]
	state.PushNumber(float64(ref))
	state.GetTable(-2)

	if state.IsFunction(-1) {
		state.Remove(-2) // keep only the function
		return true
	}

	state.Pop(2) // drop the non-function value and the registry
	return false
}
// fastRequestHandler is the fasthttp handler installed by httpListen.
//
// It resolves the route, then, while holding stateMu (every request runs on
// the one shared Lua state), runs the matched handler. The handler is either
// a stored Lua function or precompiled bytecode. Afterwards it applies the
// response metadata and session data the handler left behind and writes the
// session cookie.
//
// Responses produced here: 404 "Not Found" for unmatched routes, and 500 for
// a missing handler, a failed request-table push, or a Lua error.
// NOTE(review): Lua error text is sent to the client in the 500 body;
// consider logging it and returning a generic message instead.
func (s *Server) fastRequestHandler(ctx *fasthttp.RequestCtx) {
	method := string(ctx.Method())
	path := string(ctx.Path())

	handlerID, params, found := s.router.Lookup(method, path)
	if !found {
		ctx.SetStatusCode(404)
		ctx.SetBodyString("Not Found")
		return
	}

	s.handlersMu.RLock()
	handler := s.handlers[handlerID]
	s.handlersMu.RUnlock()
	if handler == nil {
		ctx.SetStatusCode(500)
		ctx.SetBodyString("Handler not found")
		return
	}

	s.stateMu.Lock()
	defer s.stateMu.Unlock()
	// Runs before the unlock (defers are LIFO). The stack is reset on every
	// exit, including error paths, so leftovers of a failed request are never
	// read as the next request's return value.
	defer s.state.SetTop(0)

	// Globals persist in the shared state between requests. Clear the
	// request-scoped ones so one client's params, cookies, session data or
	// response metadata are never seen by, or applied to, another client.
	s.clearRequestGlobals(s.state)

	reqCtx := &RequestContext{
		ctx:     ctx,
		params:  params,
		session: s.sessions.GetSessionFromRequest(ctx),
	}
	reqCtx.session.AdvanceFlash()

	run := s.runBytecodeHandler
	if handler.isFunction {
		run = s.runFunctionHandler
	}
	responseBody, ok := run(ctx, reqCtx, handler)
	if !ok {
		return
	}

	s.applyResponse(ctx, s.state, responseBody, reqCtx.session)
	s.sessions.ApplySessionCookie(ctx, reqCtx.session)
}

// clearRequestGlobals sets to nil every request-scoped global that the
// environment setup functions may assign.
// NOTE(review): the bytecode handlers' "response" global is not cleared here,
// because it is unclear whether http.lua defines it persistently. Confirm,
// and add it if it is per-request.
func (s *Server) clearRequestGlobals(state *luajit.State) {
	for _, name := range [...]string{
		"REQUEST_METHOD", "REQUEST_PATH", "PARAMS", "QUERY", "HEADERS",
		"COOKIES", "FORM", "CSRF_TOKEN", "session_data", "__session", "__response",
	} {
		state.PushNil()
		state.SetGlobal(name)
	}
}

// runFunctionHandler calls a stored Lua function handler with the request
// table as its only argument. It returns the handler's result as a string
// (empty for nil) and false if it has already written an error response.
func (s *Server) runFunctionHandler(ctx *fasthttp.RequestCtx, reqCtx *RequestContext, handler *HandlerFunc) (string, bool) {
	s.setupFunctionEnvironment(s.state, reqCtx)
	if !s.getFunction(s.state, handler.funcRef) {
		ctx.SetStatusCode(500)
		ctx.SetBodyString("Function handler not found")
		return "", false
	}
	if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil {
		ctx.SetStatusCode(500)
		ctx.SetBodyString("Failed to create request object")
		return "", false
	}
	if err := s.state.Call(1, 1); err != nil {
		ctx.SetStatusCode(500)
		ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
		return "", false
	}
	return s.takeResult(), true
}

// runBytecodeHandler exposes the request through globals and runs the
// handler's precompiled chunk. It returns the value left on top of the stack
// as a string and false if it has already written an error response.
func (s *Server) runBytecodeHandler(ctx *fasthttp.RequestCtx, reqCtx *RequestContext, handler *HandlerFunc) (string, bool) {
	s.setupFastEnvironment(s.state, reqCtx)
	if err := s.state.LoadAndRunBytecode(handler.bytecode, handler.name); err != nil {
		ctx.SetStatusCode(500)
		ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
		return "", false
	}
	return s.takeResult(), true
}

// takeResult returns the top stack value as a string (empty when the stack
// is empty or the value is nil) and then empties the stack. The height is
// checked first so a handler that returns nothing never causes a pop from an
// empty stack.
func (s *Server) takeResult() string {
	var body string
	if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
		body = s.state.ToString(-1)
	}
	s.state.SetTop(0)
	return body
}
// setupFunctionEnvironment prepares globals for a Lua function handler.
// It always installs a fresh __response = {status = 200, headers = {}, cookies = {}}.
// When the session holds data, it also publishes that data as __session.
// NOTE(review): __session is not reset here when the session is empty.
func (s *Server) setupFunctionEnvironment(state *luajit.State, reqCtx *RequestContext) {
	state.NewTable()
	state.PushNumber(200)
	state.SetField(-2, "status")
	for _, field := range []string{"headers", "cookies"} {
		state.NewTable()
		state.SetField(-2, field)
	}
	state.SetGlobal("__response")

	if reqCtx.session.IsEmpty() {
		return
	}
	state.PushValue(reqCtx.session.GetAll())
	state.SetGlobal("__session")
}
// setupFastEnvironment exposes the current request to a bytecode handler
// through Lua globals:
//
//	REQUEST_METHOD, REQUEST_PATH  always set (strings)
//	HEADERS                       always set (name -> value)
//	PARAMS, QUERY, COOKIES        set only when non-empty
//	FORM                          POST/PUT/PATCH only, when parseForm yields data
//	session_data                  set only when the session holds data
//	CSRF_TOKEN                    set when a token was generated
//
// It also registers the Go function json_encode_fallback, which returns the
// JSON encoding of its first argument, or "null" if encoding fails.
// NOTE(review): optional globals are not cleared here when absent, and
// json_encode_fallback is re-registered on every request.
func (s *Server) setupFastEnvironment(state *luajit.State, reqCtx *RequestContext) {
	rc := reqCtx.ctx

	// collect copies every key/value pair reported by a fasthttp visitor.
	collect := func(visit func(func(key, value []byte))) map[string]string {
		out := make(map[string]string)
		visit(func(k, v []byte) {
			out[string(k)] = string(v)
		})
		return out
	}

	state.PushString(string(rc.Method()))
	state.SetGlobal("REQUEST_METHOD")
	state.PushString(string(rc.Path()))
	state.SetGlobal("REQUEST_PATH")

	if p := reqCtx.params; p != nil && len(p.Keys) > 0 {
		byName := make(map[string]string, len(p.Keys))
		for i, key := range p.Keys {
			if i < len(p.Values) {
				byName[key] = p.Values[i]
			}
		}
		state.PushValue(byName)
		state.SetGlobal("PARAMS")
	}

	if query := collect(rc.QueryArgs().VisitAll); len(query) > 0 {
		state.PushValue(query)
		state.SetGlobal("QUERY")
	}

	state.PushValue(collect(rc.Request.Header.VisitAll))
	state.SetGlobal("HEADERS")

	if cookies := collect(rc.Request.Header.VisitAllCookie); len(cookies) > 0 {
		state.PushValue(cookies)
		state.SetGlobal("COOKIES")
	}

	if rc.IsPost() || rc.IsPut() || rc.IsPatch() {
		if form := s.parseForm(rc); len(form) > 0 {
			state.PushValue(form)
			state.SetGlobal("FORM")
		}
	}

	if !reqCtx.session.IsEmpty() {
		state.PushValue(reqCtx.session.GetAll())
		state.SetGlobal("session_data")
	}

	if token := s.generateCSRFToken(); token != "" {
		state.PushString(token)
		state.SetGlobal("CSRF_TOKEN")
	}

	var encodeFallback luajit.GoFunction = func(ls *luajit.State) int {
		val, _ := ls.ToValue(1)
		encoded, err := json.Marshal(val)
		if err != nil {
			ls.PushString("null")
			return 1
		}
		ls.PushString(string(encoded))
		return 1
	}
	state.RegisterGoFunction("json_encode_fallback", encodeFallback)
}
// requestToTable builds the request object passed to Lua function handlers.
// It has the keys method, path, body (raw POST body), and headers, query and
// cookies (string -> string maps). It also has params, when the route matched
// any parameters, and form, for POST/PUT/PATCH requests, holding parseForm's
// result.
func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any {
	rc := reqCtx.ctx
	method := string(rc.Method())
	path := string(rc.Path())
	body := string(rc.PostBody())

	headers := make(map[string]string)
	rc.Request.Header.VisitAll(func(k, v []byte) {
		headers[string(k)] = string(v)
	})

	cookies := make(map[string]string)
	rc.Request.Header.VisitAllCookie(func(k, v []byte) {
		cookies[string(k)] = string(v)
	})

	query := make(map[string]string)
	rc.QueryArgs().VisitAll(func(k, v []byte) {
		query[string(k)] = string(v)
	})

	req := map[string]any{
		"method":  method,
		"path":    path,
		"headers": headers,
		"query":   query,
		"cookies": cookies,
		"body":    body,
	}

	if p := reqCtx.params; p != nil && len(p.Keys) > 0 {
		byName := make(map[string]string, len(p.Keys))
		for i, key := range p.Keys {
			if i < len(p.Values) {
				byName[key] = p.Values[i]
			}
		}
		req["params"] = byName
	}

	if rc.IsPost() || rc.IsPut() || rc.IsPatch() {
		req["form"] = s.parseForm(rc)
	}
	return req
}
// applyResponse copies what a Lua handler left in globals onto the response:
//
//  1. If the global session_data is a table that converts to a
//     map[string]any, the session is cleared and refilled from it.
//  2. Response metadata comes from the global __response (function handlers)
//     or, failing that, response (bytecode handlers). Its status is applied
//     when it is not 200, as are its headers (string key/value pairs) and
//     its cookies array.
//  3. A non-empty body becomes the response body.
//
// Only tables are accepted as response metadata. A non-table value is
// skipped rather than indexed, because field lookups on values such as
// numbers or booleans raise a Lua error.
func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, state *luajit.State, body string, session *sessions.Session) {
	state.GetGlobal("session_data")
	if state.IsTable(-1) {
		if data, err := state.ToTable(-1); err == nil {
			if dataMap, ok := data.(map[string]any); ok {
				session.Clear()
				for k, v := range dataMap {
					session.Set(k, v)
				}
			}
		}
	}
	state.Pop(1)

	if s.pushResponseTable(state) {
		if status := state.GetFieldNumber(-1, "status", 200); status != 200 {
			ctx.SetStatusCode(int(status))
		}

		state.GetField(-1, "headers")
		if state.IsTable(-1) {
			state.ForEachTableKV(-1, func(key, value string) bool {
				ctx.Response.Header.Set(key, value)
				return true
			})
		}
		state.Pop(1)

		state.GetField(-1, "cookies")
		if state.IsTable(-1) {
			s.applyCookies(ctx, state)
		}
		state.Pop(1)

		state.Pop(1) // the response table itself
	}

	if body != "" {
		ctx.SetBodyString(body)
	}
}

// pushResponseTable pushes the first of the globals __response and response
// that holds a table and reports whether one was found. It leaves nothing on
// the stack when it returns false.
func (s *Server) pushResponseTable(state *luajit.State) bool {
	for _, name := range [...]string{"__response", "response"} {
		state.GetGlobal(name)
		if state.IsTable(-1) {
			return true
		}
		state.Pop(1)
	}
	return false
}
// applyCookies adds a Set-Cookie header for each element of the Lua array on
// top of the stack. Each element is a table {name=, value=, options=}.
// Elements that are not tables, or that have an empty name, are skipped.
//
// NOTE(review): the defaults Path="/" and HttpOnly=true are applied only when
// an options table is present; a cookie without options gets neither.
// Confirm that this is intended.
func (s *Server) applyCookies(ctx *fasthttp.RequestCtx, state *luajit.State) {
	state.ForEachArray(-1, func(i int, st *luajit.State) bool {
		if !st.IsTable(-1) {
			return true
		}
		name := st.GetFieldString(-1, "name", "")
		value := st.GetFieldString(-1, "value", "")
		if name == "" {
			return true
		}

		cookie := fasthttp.AcquireCookie()
		// Released when this callback returns; SetCookie below copies the
		// cookie, so reuse afterwards is safe.
		defer fasthttp.ReleaseCookie(cookie)
		cookie.SetKey(name)
		cookie.SetValue(value)

		if options, ok := st.GetFieldTable(-1, "options"); ok {
			if optMap, ok := options.(map[string]any); ok {
				applyCookieOptions(cookie, optMap)
			}
		}

		ctx.Response.Header.SetCookie(cookie)
		return true
	})
}

// applyCookieOptions applies the recognised keys of a Lua cookie options
// table: path (default "/"), domain, secure, http_only (default true) and
// max_age (seconds; a positive value sets Expires to now + max_age).
func applyCookieOptions(cookie *fasthttp.Cookie, opts map[string]any) {
	if path, ok := opts["path"].(string); ok {
		cookie.SetPath(path)
	} else {
		cookie.SetPath("/")
	}
	if domain, ok := opts["domain"].(string); ok {
		cookie.SetDomain(domain)
	}
	if secure, ok := opts["secure"].(bool); ok && secure {
		cookie.SetSecure(true)
	}
	if httpOnly, ok := opts["http_only"].(bool); ok {
		cookie.SetHTTPOnly(httpOnly)
	} else {
		cookie.SetHTTPOnly(true)
	}
	// Lua numbers may be decoded as float64 rather than int; accept either so
	// max_age is not silently ignored.
	if maxAge, ok := cookieNumberToInt(opts["max_age"]); ok && maxAge > 0 {
		cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second))
	}
}

// cookieNumberToInt converts a numeric option value to int. It accepts the
// integer and floating-point types a Lua number may be decoded into;
// fractional parts are truncated.
func cookieNumberToInt(v any) (int, bool) {
	switch n := v.(type) {
	case int:
		return n, true
	case int32:
		return int(n), true
	case int64:
		return int(n), true
	case float32:
		return int(n), true
	case float64:
		return int(n), true
	}
	return 0, false
}
// parseForm decodes the request body according to its Content-Type:
//
//   - application/json: the decoded object is returned as-is. If decoding
//     fails, an empty map is returned.
//   - application/x-www-form-urlencoded: each field maps to its string value.
//   - multipart/form-data: a field with exactly one value maps to a string,
//     otherwise to []string. Uploaded files are summarised under "_files"
//     (field -> fileToMap result, or a list of them for repeated fields).
//
// Any other content type yields an empty map.
func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any {
	contentType := string(ctx.Request.Header.ContentType())
	form := make(map[string]any)

	switch {
	case strings.Contains(contentType, "application/json"):
		var data map[string]any
		if err := json.Unmarshal(ctx.PostBody(), &data); err == nil {
			return data
		}

	case strings.Contains(contentType, "application/x-www-form-urlencoded"):
		ctx.PostArgs().VisitAll(func(k, v []byte) {
			form[string(k)] = string(v)
		})

	case strings.Contains(contentType, "multipart/form-data"):
		mf, err := ctx.MultipartForm()
		if err != nil {
			break
		}
		for key, values := range mf.Value {
			if len(values) == 1 {
				form[key] = values[0]
				continue
			}
			form[key] = values
		}
		if len(mf.File) == 0 {
			break
		}
		files := make(map[string]any, len(mf.File))
		for field, headers := range mf.File {
			if len(headers) == 1 {
				files[field] = s.fileToMap(headers[0])
				continue
			}
			list := make([]map[string]any, 0, len(headers))
			for _, fh := range headers {
				list = append(list, s.fileToMap(fh))
			}
			files[field] = list
		}
		form["_files"] = files
	}

	return form
}
// fileToMap summarises an uploaded file for Lua. It returns its client-side
// filename, its size in bytes and the Content-Type from its part header.
func (s *Server) fileToMap(fh *multipart.FileHeader) map[string]any {
	info := make(map[string]any, 3)
	info["filename"] = fh.Filename
	info["size"] = fh.Size
	info["mimetype"] = fh.Header.Get("Content-Type")
	return info
}
// generateCSRFToken returns 32 bytes from crypto/rand encoded as URL-safe
// base64 with padding (44 characters). If the system random source fails it
// returns "" rather than an all-zero, predictable token; setupFastEnvironment
// only publishes non-empty tokens.
// NOTE(review): the token is not stored in the session, so nothing in this
// file can validate it later.
func (s *Server) generateCSRFToken() string {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return ""
	}
	return base64.URLEncoding.EncodeToString(buf)
}