// Moonshark/http/http.go
// 2025-07-14 16:03:02 -05:00
//
// 757 lines
// 16 KiB
// Go

package http
import (
	"crypto/rand"
	"crypto/subtle"
	_ "embed"
	"encoding/base64"
	"fmt"
	"mime/multipart"
	"net"
	"strings"
	"sync"
	"time"

	"Moonshark/http/router"
	"Moonshark/http/sessions"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
	"github.com/goccy/go-json"
	"github.com/valyala/fasthttp"
)
// httpLuaCode holds the contents of http.lua, embedded at build time.
// RegisterHTTPFunctions runs it with DoString after registering the Go
// callbacks (presumably http.lua wraps those __-prefixed globals in a
// friendlier Lua API — confirm against http.lua).
//
//go:embed http.lua
var httpLuaCode string
// Server ties together a fasthttp server, a route table, a session manager
// and the single Lua state in which route handlers run.
type Server struct {
	// server is created by httpListen; it stays nil until listen is called.
	server *fasthttp.Server
	// router maps method+path to the integer handler references returned
	// by storeFunction.
	router *router.Router
	// sessions looks up the session for each request and writes its cookie.
	sessions *sessions.SessionManager
	// state is the shared Lua state. requestHandler holds stateMu while
	// using it, because fasthttp serves requests on multiple goroutines.
	state *luajit.State
	// stateMu serializes request handlers' access to state.
	stateMu sync.Mutex
	// funcCounter is the last key assigned in __moonshark_functions.
	funcCounter int
}
// RequestContext is a snapshot of one incoming request, built by
// buildRequestContext before the Lua handler is invoked.
type RequestContext struct {
	Method  string            // HTTP method, from ctx.Method()
	Path    string            // request path, from ctx.Path()
	Headers map[string]string // request headers; a repeated header keeps the last value visited
	Query   map[string]string // query arguments; a repeated key keeps the last value visited
	Form    map[string]any    // parsed request body (see parseForm)
	Cookies map[string]string // request cookies by name
	Session *sessions.Session // session resolved for this request
	Body    string            // raw request body
	Params  map[string]string // route parameters captured by the router
}
// globalServer is the Server created by RegisterHTTPFunctions. Package-level
// callbacks that have no receiver (e.g. cookieSet) reach it through this
// variable.
var globalServer *Server
// NewServer returns a Server bound to the given Lua state, with an empty
// router and a session manager created via NewSessionManager(10000)
// (presumably a capacity bound — confirm in the sessions package).
// No network listener exists until httpListen is called.
func NewServer(state *luajit.State) *Server {
	srv := new(Server)
	srv.router = router.New()
	srv.sessions = sessions.NewSessionManager(10000)
	srv.state = state
	return srv
}
// RegisterHTTPFunctions creates the package-level server for L, registers
// every Go callback backing the Lua HTTP/session/cookie/CSRF API under its
// __-prefixed name, and then executes the embedded http.lua. The first
// registration error aborts setup and is returned unchanged.
func RegisterHTTPFunctions(L *luajit.State) error {
	globalServer = NewServer(L)
	srv := globalServer

	bindings := []struct {
		name string
		fn   luajit.GoFunction
	}{
		{"__http_listen", srv.httpListen},
		{"__http_route", srv.httpRoute},
		{"__http_set_status", httpSetStatus},
		{"__http_set_header", httpSetHeader},
		{"__http_redirect", httpRedirect},
		{"__session_get", srv.sessionGet},
		{"__session_set", srv.sessionSet},
		{"__session_flash", srv.sessionFlash},
		{"__session_get_flash", srv.sessionGetFlash},
		{"__cookie_set", cookieSet},
		{"__cookie_get", cookieGet},
		{"__csrf_generate", srv.csrfGenerate},
		{"__csrf_validate", srv.csrfValidate},
	}

	for _, b := range bindings {
		if err := L.RegisterGoFunction(b.name, b.fn); err != nil {
			return err
		}
	}
	return L.DoString(httpLuaCode)
}
// httpListen implements __http_listen(port).
//
// It binds the listening socket synchronously, so failures such as "address
// already in use" are raised as Lua errors. Previously they were only printed
// from a goroutine after the call had already returned true. Requests are then
// served on a background goroutine. Like fasthttp's ListenAndServe, it listens
// on the "tcp4" network.
//
// Lua args: port (number in 0..65535). Returns true on success.
func (s *Server) httpListen(state *luajit.State) int {
	port, err := state.SafeToNumber(1)
	if err != nil {
		return state.PushError("listen: port must be number")
	}
	// Written so that NaN also fails the check.
	if !(port >= 0 && port <= 65535) {
		return state.PushError("listen: port %v out of range", port)
	}

	addr := fmt.Sprintf(":%d", int(port))
	ln, err := net.Listen("tcp4", addr)
	if err != nil {
		return state.PushError("listen: %s", err.Error())
	}

	s.server = &fasthttp.Server{
		Handler: s.requestHandler,
		Name:    "Moonshark/1.0",
	}
	// Capture the server so a later listen call reassigning s.server cannot
	// affect this goroutine.
	srv := s.server
	go func() {
		if err := srv.Serve(ln); err != nil {
			fmt.Printf("Server error: %v\n", err)
		}
	}()

	fmt.Printf("Server listening on port %d\n", int(port))
	state.PushBoolean(true)
	return 1
}
// httpRoute implements __http_route(method, path, handler).
//
// The handler function is saved in __moonshark_functions via storeFunction.
// The router records the resulting integer reference under the upper-cased
// method and the path. It returns true, or raises a Lua error for bad
// arguments or when the router rejects the route.
func (s *Server) httpRoute(state *luajit.State) int {
	method, methodErr := state.SafeToString(1)
	if methodErr != nil {
		return state.PushError("route: method must be string")
	}
	path, pathErr := state.SafeToString(2)
	if pathErr != nil {
		return state.PushError("route: path must be string")
	}
	if !state.IsFunction(3) {
		return state.PushError("route: handler must be function")
	}

	// storeFunction consumes the value on top of the stack.
	state.PushCopy(3)
	ref := s.storeFunction()

	verb := strings.ToUpper(method)
	if addErr := s.router.AddRoute(verb, path, ref); addErr != nil {
		return state.PushError("route: failed to add route: %s", addErr.Error())
	}

	state.PushBoolean(true)
	return 1
}
// storeFunction pops the value on top of the Lua stack and saves it in the
// global table __moonshark_functions, creating that table on first use. The
// value is stored under the next integer key, which is returned.
// It operates on s.state, so the caller must have pushed the value onto that
// same state. NOTE(review): funcCounter is not guarded by stateMu here —
// this assumes routes are only registered from the main script.
func (s *Server) storeFunction() int {
	s.state.GetGlobal("__moonshark_functions")
	if s.state.IsNil(-1) {
		s.state.Pop(1)
		s.state.NewTable()
		s.state.PushCopy(-1)
		s.state.SetGlobal("__moonshark_functions")
	}
	// Stack: [..., value, fnTable]
	s.funcCounter++
	s.state.PushNumber(float64(s.funcCounter))
	s.state.PushCopy(-3) // copy of the value being stored
	s.state.SetTable(-3) // fnTable[funcCounter] = value
	s.state.Pop(2)       // drop fnTable and the original value
	return s.funcCounter
}
// getFunction pushes the function stored under ref in __moonshark_functions
// onto s.state and returns true. If the table is missing or the entry is not
// a function, the stack is left as it was found and false is returned.
func (s *Server) getFunction(ref int) bool {
	s.state.GetGlobal("__moonshark_functions")
	if s.state.IsNil(-1) {
		s.state.Pop(1)
		return false
	}
	// Stack: [..., fnTable]
	s.state.PushNumber(float64(ref))
	s.state.GetTable(-2)
	// Stack: [..., fnTable, fnTable[ref]]
	isFunc := s.state.IsFunction(-1)
	if !isFunc {
		s.state.Pop(2)
		return false
	}
	// Drop fnTable, leaving only the function on top.
	s.state.Remove(-2)
	return true
}
// requestHandler is the fasthttp entry point. It resolves the route and builds
// a RequestContext. Then, holding stateMu, it:
//   - publishes __request/__session/__response to Lua;
//   - calls the stored handler with the request table;
//   - applies the handler's return value (as the body), __response and
//     __session to the HTTP response and the Go session.
//
// The Lua stack height is restored before the lock is released, on every
// path. Call(1, 1) always leaves exactly one result. The previous code left
// that result behind whenever the handler returned nil, and left the handler
// function behind when building the request object failed. Either way the
// shared stack grew on every such request.
func (s *Server) requestHandler(ctx *fasthttp.RequestCtx) {
	method := string(ctx.Method())
	path := string(ctx.Path())

	handlerRef, params, found := s.router.Lookup(method, path)
	if !found {
		ctx.SetStatusCode(fasthttp.StatusNotFound)
		ctx.SetBodyString("Not Found")
		return
	}

	reqCtx := s.buildRequestContext(ctx, params)
	reqCtx.Session.AdvanceFlash()

	s.stateMu.Lock()
	defer s.stateMu.Unlock()

	// Deferred calls run LIFO, so this restore happens while the lock is held.
	base := s.state.GetTop()
	defer func() {
		if extra := s.state.GetTop() - base; extra > 0 {
			s.state.Pop(extra)
		}
	}()

	s.setupRequestEnvironment(reqCtx)

	if !s.getFunction(handlerRef) {
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		ctx.SetBodyString("Handler not found")
		return
	}
	if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil {
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		ctx.SetBodyString("Failed to create request object")
		return
	}
	if err := s.state.Call(1, 1); err != nil {
		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
		ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
		return
	}

	// Exactly one result is expected; pop it whether or not it is nil.
	var responseBody string
	if s.state.GetTop() > base {
		if !s.state.IsNil(-1) {
			responseBody = s.state.ToString(-1)
		}
		s.state.Pop(1)
	}

	s.updateSessionFromLua(reqCtx.Session)
	s.applyResponse(ctx, responseBody)
	s.sessions.ApplySessionCookie(ctx, reqCtx.Session)
}
// setupRequestEnvironment publishes the per-request Lua globals:
//   - __request: see requestToTable;
//   - __session: {id, data};
//   - __response: a fresh empty table.
//
// If a value cannot be converted, whatever PushValue left on the stack is
// discarded and an empty table is published instead. Previously the error was
// ignored, so SetGlobal could pop an unrelated stack slot. An empty __session
// has no data field, so updateSessionFromLua leaves the Go session unchanged.
func (s *Server) setupRequestEnvironment(reqCtx *RequestContext) {
	publish := func(name string, value map[string]any) {
		top := s.state.GetTop()
		if err := s.state.PushValue(value); err != nil {
			if extra := s.state.GetTop() - top; extra > 0 {
				s.state.Pop(extra)
			}
			s.state.NewTable()
		}
		s.state.SetGlobal(name)
	}

	publish("__request", s.requestToTable(reqCtx))
	publish("__session", s.sessionToTable(reqCtx.Session))

	s.state.NewTable()
	s.state.SetGlobal("__response")
}
// requestToTable flattens a RequestContext into the map that is pushed to Lua
// as the request object. The Session field is not included; it is exposed to
// Lua separately as __session.
func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any {
	tbl := make(map[string]any, 8)
	tbl["method"] = reqCtx.Method
	tbl["path"] = reqCtx.Path
	tbl["headers"] = reqCtx.Headers
	tbl["query"] = reqCtx.Query
	tbl["form"] = reqCtx.Form
	tbl["cookies"] = reqCtx.Cookies
	tbl["body"] = reqCtx.Body
	tbl["params"] = reqCtx.Params
	return tbl
}
// sessionToTable builds the {id, data} map published to Lua as __session,
// where data is the session's GetAll() result.
func (s *Server) sessionToTable(session *sessions.Session) map[string]any {
	id := session.ID
	data := session.GetAll()
	return map[string]any{"id": id, "data": data}
}
// updateSessionFromLua copies __session.data back into the Go session,
// replacing its previous contents. The session is left untouched in three
// cases: __session is nil, data is not a table, or ToTable does not yield a
// map[string]any.
//
// NOTE(review): if ToTable returns a non-map (e.g. a slice) for an empty
// table, deleting the last key from Lua would not clear the session —
// confirm against LuaJIT-to-Go's ToTable.
func (s *Server) updateSessionFromLua(session *sessions.Session) {
	ls := s.state
	ls.GetGlobal("__session")
	if ls.IsNil(-1) {
		ls.Pop(1)
		return
	}

	var (
		incoming map[string]any
		isMap    bool
	)
	ls.GetField(-1, "data")
	if ls.IsTable(-1) {
		if raw, err := ls.ToTable(-1); err == nil {
			incoming, isMap = raw.(map[string]any)
		}
	}
	ls.Pop(2) // data field and __session

	if !isMap {
		return
	}
	session.Clear()
	for key, val := range incoming {
		session.Set(key, val)
	}
}
// applyResponse copies the Lua __response table onto the fasthttp response:
//   - a numeric status becomes the status code;
//   - every key/value pair in headers is set as a header;
//   - a cookies table is handed to applyCookies.
//
// A non-empty body (the handler's return value) is written last. With no
// __response table, only the body is applied. The Lua stack is left balanced.
func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, body string) {
	ls := s.state
	ls.GetGlobal("__response")
	if ls.IsNil(-1) {
		ls.Pop(1)
	} else {
		ls.GetField(-1, "status")
		if ls.IsNumber(-1) {
			ctx.SetStatusCode(int(ls.ToNumber(-1)))
		}
		ls.Pop(1)

		ls.GetField(-1, "headers")
		if ls.IsTable(-1) {
			ls.ForEachTableKV(-1, func(name, val string) bool {
				ctx.Response.Header.Set(name, val)
				return true
			})
		}
		ls.Pop(1)

		// applyCookies expects the cookies table on top of the stack.
		ls.GetField(-1, "cookies")
		if ls.IsTable(-1) {
			s.applyCookies(ctx)
		}
		ls.Pop(2) // cookies field and __response
	}

	if body != "" {
		ctx.SetBodyString(body)
	}
}
// applyCookies expects an array of cookie tables on top of s.state's stack
// (the __response.cookies built by cookieSet). Each table entry with a
// non-empty name becomes a Set-Cookie header. Recognized fields:
//   - name, value;
//   - path (default "/"), domain (applied only when non-empty);
//   - secure (default false), http_only (default true);
//   - max_age in seconds, applied only when positive as an absolute expiry
//     from now.
//
// Non-table entries and entries without a name are skipped.
func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
	s.state.ForEachArray(-1, func(_ int, ls *luajit.State) bool {
		if ls.IsTable(-1) {
			name := ls.GetFieldString(-1, "name", "")
			value := ls.GetFieldString(-1, "value", "")
			if name != "" {
				c := fasthttp.AcquireCookie()
				// Fires when this callback returns; SetCookie copies the cookie.
				defer fasthttp.ReleaseCookie(c)

				c.SetKey(name)
				c.SetValue(value)
				c.SetPath(ls.GetFieldString(-1, "path", "/"))
				if domain := ls.GetFieldString(-1, "domain", ""); domain != "" {
					c.SetDomain(domain)
				}
				if ls.GetFieldBool(-1, "secure", false) {
					c.SetSecure(true)
				}
				if ls.GetFieldBool(-1, "http_only", true) {
					c.SetHTTPOnly(true)
				}
				if seconds := ls.GetFieldNumber(-1, "max_age", 0); seconds > 0 {
					c.SetExpire(time.Now().Add(time.Duration(seconds) * time.Second))
				}
				ctx.Response.Header.SetCookie(c)
			}
		}
		return true
	})
}
// buildRequestContext snapshots everything a Lua handler can see about the
// request: method, path, raw body, headers, cookies, query arguments, route
// params, the parsed form (see parseForm) and the client's session.
// Repeated headers, cookies or query keys keep the last value visited.
func (s *Server) buildRequestContext(ctx *fasthttp.RequestCtx, params *router.Params) *RequestContext {
	rc := &RequestContext{
		Method:  string(ctx.Method()),
		Path:    string(ctx.Path()),
		Body:    string(ctx.PostBody()),
		Headers: map[string]string{},
		Query:   map[string]string{},
		Cookies: map[string]string{},
		Params:  map[string]string{},
	}

	ctx.Request.Header.VisitAll(func(k, v []byte) { rc.Headers[string(k)] = string(v) })
	ctx.Request.Header.VisitAllCookie(func(k, v []byte) { rc.Cookies[string(k)] = string(v) })
	ctx.QueryArgs().VisitAll(func(k, v []byte) { rc.Query[string(k)] = string(v) })

	// Pair route keys with captured values; keys without a value are skipped.
	for i, name := range params.Keys {
		if i >= len(params.Values) {
			break
		}
		rc.Params[name] = params.Values[i]
	}

	rc.Form = s.parseForm(ctx)
	rc.Session = s.sessions.GetSessionFromRequest(ctx)
	return rc
}
// parseForm decodes the request body according to its Content-Type:
//   - application/json: the decoded top-level object;
//   - application/x-www-form-urlencoded: field -> value (a repeated field
//     keeps the last value visited);
//   - multipart/form-data: field -> value, or []string when repeated, plus a
//     "_files" entry mapping field -> file metadata (see fileToMap).
//
// Any other content type, or a body that fails to parse, yields an empty
// map. That includes JSON that is not an object and the literal `null`,
// which previously produced a nil map. The result is never nil.
func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any {
	contentType := string(ctx.Request.Header.ContentType())
	form := make(map[string]any)

	switch {
	case strings.Contains(contentType, "application/json"):
		var data map[string]any
		// A `null` body decodes without error but leaves data nil.
		if err := json.Unmarshal(ctx.PostBody(), &data); err == nil && data != nil {
			return data
		}

	case strings.Contains(contentType, "application/x-www-form-urlencoded"):
		ctx.PostArgs().VisitAll(func(key, value []byte) {
			form[string(key)] = string(value)
		})

	case strings.Contains(contentType, "multipart/form-data"):
		mf, err := ctx.MultipartForm()
		if err != nil {
			break
		}
		for key, values := range mf.Value {
			if len(values) == 1 {
				form[key] = values[0]
			} else {
				form[key] = values
			}
		}
		if len(mf.File) > 0 {
			files := make(map[string]any, len(mf.File))
			for fieldName, headers := range mf.File {
				if len(headers) == 1 {
					files[fieldName] = s.fileToMap(headers[0])
					continue
				}
				list := make([]map[string]any, len(headers))
				for i, fh := range headers {
					list[i] = s.fileToMap(fh)
				}
				files[fieldName] = list
			}
			form["_files"] = files
		}
	}
	return form
}
// fileToMap describes an uploaded file for Lua: the client-supplied filename,
// the size in bytes, and the part's Content-Type header. File contents are
// not included.
func (s *Server) fileToMap(fh *multipart.FileHeader) map[string]any {
	info := make(map[string]any, 3)
	info["filename"] = fh.Filename
	info["size"] = fh.Size
	info["mimetype"] = fh.Header.Get("Content-Type")
	return info
}
// generateCSRFToken returns 32 bytes from crypto/rand encoded as URL-safe
// base64 (44 characters).
//
// If the system random source fails, it returns "" rather than a guessable
// token. The previous code encoded the zeroed or partially filled buffer
// anyway. csrfValidate never accepts an empty session token, so a failure
// makes validation fail closed.
func (s *Server) generateCSRFToken() string {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return ""
	}
	return base64.URLEncoding.EncodeToString(buf)
}
// Lua function implementations

// httpSetStatus implements __http_set_status(code). It stores code in
// __response.status, creating __response if it is nil. A conversion error is
// ignored; the stored value is then whatever SafeToNumber returns on error
// (presumably 0).
func httpSetStatus(state *luajit.State) int {
	status, _ := state.SafeToNumber(1)

	state.GetGlobal("__response")
	if state.IsNil(-1) {
		// Lazily create __response, then fetch it back onto the stack.
		state.Pop(1)
		state.NewTable()
		state.SetGlobal("__response")
		state.GetGlobal("__response")
	}

	state.PushNumber(status)
	state.SetField(-2, "status") // __response.status = status
	state.Pop(1)                 // __response
	return 0
}
// httpSetHeader implements __http_set_header(name, value). It stores value in
// __response.headers[name], creating __response and its headers table when
// they are nil. Conversion errors on either argument are ignored.
func httpSetHeader(state *luajit.State) int {
	headerName, _ := state.SafeToString(1)
	headerValue, _ := state.SafeToString(2)

	state.GetGlobal("__response")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.SetGlobal("__response")
		state.GetGlobal("__response")
	}

	// Stack: [..., response]
	state.GetField(-1, "headers")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.PushCopy(-1)
		state.SetField(-3, "headers") // response.headers = new table
	}

	// Stack: [..., response, headers]
	state.PushString(headerValue)
	state.SetField(-2, headerName)
	state.Pop(2)
	return 0
}
// httpRedirect implements __http_redirect(url [, status]). It sets
// __response.status and the Location header, creating __response and its
// headers table when they are nil. The status defaults to 302.
//
// A second argument that is nil or not a number now keeps the 302 default.
// Previously any second argument, including an explicit nil, replaced the
// status with SafeToNumber's error-case value (presumably 0), so the
// response carried no redirect status.
func httpRedirect(state *luajit.State) int {
	url, _ := state.SafeToString(1)

	status := 302.0
	if state.GetTop() >= 2 {
		if n, err := state.SafeToNumber(2); err == nil {
			status = n
		}
	}

	state.GetGlobal("__response")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.SetGlobal("__response")
		state.GetGlobal("__response")
	}

	// Stack: [..., response]
	state.PushNumber(status)
	state.SetField(-2, "status")

	state.GetField(-1, "headers")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.PushCopy(-1)
		state.SetField(-3, "headers")
	}

	// Stack: [..., response, headers]
	state.PushString(url)
	state.SetField(-2, "Location")
	state.Pop(2)
	return 0
}
// sessionGet implements __session_get(key). It returns __session.data[key],
// or nil when __session or its data field is nil. Only the topmost stack
// value is returned to Lua; the tables beneath it are discarded on return.
func (s *Server) sessionGet(state *luajit.State) int {
	key, _ := state.SafeToString(1)

	state.GetGlobal("__session")
	if state.IsNil(-1) {
		return 1 // the nil global is itself the result
	}
	state.GetField(-1, "data")
	if state.IsNil(-1) {
		return 1 // the nil data field is itself the result
	}
	state.GetField(-1, key)
	return 1
}
// sessionSet implements __session_set(key, value). It writes value into
// __session.data[key], creating __session and its data table when they are
// nil. The value is round-tripped through a Go value (ToValue, then
// PushValue). A value that ToValue cannot convert is silently ignored.
func (s *Server) sessionSet(state *luajit.State) int {
	key, _ := state.SafeToString(1)

	state.GetGlobal("__session")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.SetGlobal("__session")
		state.GetGlobal("__session")
	}

	state.GetField(-1, "data")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.PushCopy(-1)
		state.SetField(-3, "data")
	}

	// Stack: [..., session, data]; argument 2 is still addressable absolutely.
	if converted, convErr := state.ToValue(2); convErr == nil {
		state.PushValue(converted)
		state.SetField(-2, key)
	}
	state.Pop(2)
	return 0
}
// sessionFlash implements __session_flash(key, value). It writes value into
// __session.flash[key], creating __session and its flash table when they are
// nil. Values that ToValue cannot convert are silently ignored.
//
// NOTE(review): in the visible code, updateSessionFromLua only copies
// __session.data back to the Go session, and sessionToTable never populates
// flash. Flash values therefore appear to live only for the current request —
// confirm, and wire this to the session's flash API if so.
func (s *Server) sessionFlash(state *luajit.State) int {
	key, _ := state.SafeToString(1)

	state.GetGlobal("__session")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.SetGlobal("__session")
		state.GetGlobal("__session")
	}

	state.GetField(-1, "flash")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.PushCopy(-1)
		state.SetField(-3, "flash")
	}

	// Stack: [..., session, flash]
	if converted, convErr := state.ToValue(2); convErr == nil {
		state.PushValue(converted)
		state.SetField(-2, key)
	}
	state.Pop(2)
	return 0
}
// sessionGetFlash implements __session_get_flash(key). It returns
// __session.flash[key], or nil when __session or its flash field is nil.
func (s *Server) sessionGetFlash(state *luajit.State) int {
	key, _ := state.SafeToString(1)

	state.GetGlobal("__session")
	if state.IsNil(-1) {
		return 1 // the nil global is itself the result
	}
	state.GetField(-1, "flash")
	if state.IsNil(-1) {
		return 1 // the nil flash field is itself the result
	}
	state.GetField(-1, key)
	return 1
}
// cookieSet implements __cookie_set(name, value [, options]). It appends a
// cookie description table to __response.cookies, creating __response and
// the cookies table when they are nil. applyResponse later passes that table
// to applyCookies, which emits the Set-Cookie headers.
//
// options (optional table):
//   - max_age: seconds; stored only when > 0;
//   - path: default "/";
//   - domain: stored only when non-empty;
//   - secure: default false;
//   - http_only: default true.
//
// Conversion errors on name/value are ignored.
//
// NOTE(review): the append index comes from globalServer.getTableLength.
// That counts every key in the cookies table and operates on globalServer's
// own Lua state rather than `state`, so it assumes both are the same state.
// Confirm.
func cookieSet(state *luajit.State) int {
	name, _ := state.SafeToString(1)
	value, _ := state.SafeToString(2)
	maxAge := 0
	path := "/"
	domain := ""
	secure := false
	httpOnly := true
	if state.GetTop() >= 3 && state.IsTable(3) {
		maxAge = int(state.GetFieldNumber(3, "max_age", 0))
		path = state.GetFieldString(3, "path", "/")
		domain = state.GetFieldString(3, "domain", "")
		secure = state.GetFieldBool(3, "secure", false)
		httpOnly = state.GetFieldBool(3, "http_only", true)
	}
	state.GetGlobal("__response")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.SetGlobal("__response")
		state.GetGlobal("__response")
	}
	state.GetField(-1, "cookies")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.PushCopy(-1)
		state.SetField(-3, "cookies")
	}
	// Stack: [..., response, cookies]
	cookieData := map[string]any{
		"name":      name,
		"value":     value,
		"path":      path,
		"secure":    secure,
		"http_only": httpOnly,
	}
	if domain != "" {
		cookieData["domain"] = domain
	}
	if maxAge > 0 {
		cookieData["max_age"] = maxAge
	}
	state.PushValue(cookieData)
	// Stack: [..., response, cookies, cookieData]
	length := globalServer.getTableLength(-2)
	state.PushNumber(float64(length + 1))
	state.PushCopy(-2)
	state.SetTable(-4) // cookies[length+1] = cookieData
	state.Pop(3)       // cookieData, cookies, response
	return 0
}
// cookieGet implements __cookie_get(name). It returns
// __request.cookies[name], or nil when __request or its cookies field is nil.
func cookieGet(state *luajit.State) int {
	cookieName, _ := state.SafeToString(1)

	state.GetGlobal("__request")
	if state.IsNil(-1) {
		return 1 // the nil global is itself the result
	}
	state.GetField(-1, "cookies")
	if state.IsNil(-1) {
		return 1 // the nil cookies field is itself the result
	}
	state.GetField(-1, cookieName)
	return 1
}
// csrfGenerate implements __csrf_generate(). It creates a fresh token and
// stores it as __session.data._csrf_token, replacing any earlier token.
// __session and its data table are created when nil. The token is returned
// to Lua.
func (s *Server) csrfGenerate(state *luajit.State) int {
	csrfToken := s.generateCSRFToken()

	state.GetGlobal("__session")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.SetGlobal("__session")
		state.GetGlobal("__session")
	}

	state.GetField(-1, "data")
	if state.IsNil(-1) {
		state.Pop(1)
		state.NewTable()
		state.PushCopy(-1)
		state.SetField(-3, "data")
	}

	// Stack: [..., session, data]
	state.PushString(csrfToken)
	state.SetField(-2, "_csrf_token")
	state.Pop(2)

	state.PushString(csrfToken)
	return 1
}
// csrfValidate implements __csrf_validate(). It reports whether the request's
// form field _csrf_token matches the token that csrfGenerate stored in
// __session.data._csrf_token. A missing or empty session token never
// validates.
//
// The previous code passed dotted keys such as "data._csrf_token" to
// GetFieldString. That call presumably performs a single-level field lookup,
// as it does everywhere else in this file. The session token was therefore
// never found, and validation always failed. Each level is now traversed
// explicitly, and the tokens are compared in constant time.
func (s *Server) csrfValidate(state *luajit.State) int {
	// lookup returns global[table][key] as a string, or "" when any level
	// is missing. The stack is left balanced.
	lookup := func(global, table, key string) string {
		state.GetGlobal(global)
		if state.IsNil(-1) {
			state.Pop(1)
			return ""
		}
		state.GetField(-1, table)
		val := ""
		if state.IsTable(-1) {
			val = state.GetFieldString(-1, key, "")
		}
		state.Pop(2)
		return val
	}

	sessionToken := lookup("__session", "data", "_csrf_token")
	requestToken := lookup("__request", "form", "_csrf_token")

	valid := sessionToken != "" &&
		subtle.ConstantTimeCompare([]byte(sessionToken), []byte(requestToken)) == 1
	state.PushBoolean(valid)
	return 1
}
// getTableLength counts the entries of the table at index on s.state by
// iterating it with Next. It counts every key, not just the array part. The
// stack is left balanced.
//
// Pushing the iteration key shifts stack-relative (negative) indices by one.
// The previous code always subtracted one, which broke absolute (positive)
// indices. Valid relative indices are now converted to absolute ones first.
// Positive indices and pseudo-indices, whose magnitude exceeds the stack
// height, are already stable and are used as-is.
func (s *Server) getTableLength(index int) int {
	if top := s.state.GetTop(); index < 0 && -index <= top {
		index = top + index + 1
	}

	length := 0
	s.state.PushNil()
	for s.state.Next(index) {
		length++
		s.state.Pop(1) // drop the value, keep the key for the next Next
	}
	return length
}