diff --git a/config-example b/config-example deleted file mode 100644 index d200871..0000000 --- a/config-example +++ /dev/null @@ -1,21 +0,0 @@ -server { - port 3117 - debug false - http_logging false - static_prefix "public" -} - -runner { - pool_size 0 -- 0 defaults to GOMAXPROCS -} - -dirs = { - routes "routes" - static "public" - fs "fs" - data "data" - override "override" - libs { - "libs" - } -} diff --git a/config/config.go b/config/config.go deleted file mode 100644 index f8a065d..0000000 --- a/config/config.go +++ /dev/null @@ -1,55 +0,0 @@ -package config - -import ( - "runtime" - - fin "git.sharkk.net/Sharkk/Fin" -) - -// Config represents the current loaded configuration for the server -type Config struct { - Server struct { - Port int - Debug bool - HTTPLogging bool - PublicPrefix string - } - - Runner struct { - PoolSize int - } - - Dirs struct { - Routes string - Public string - FS string - Data string - Override string - Libs []string - } - - data *fin.Data -} - -// NewConfig creates a new configuration with default values -func New(data *fin.Data) *Config { - config := &Config{ - data: data, - } - - config.Server.Port = data.GetOr("server.port", 3117).(int) - config.Server.Debug = data.GetOr("server.debug", false).(bool) - config.Server.HTTPLogging = data.GetOr("server.http_logging", true).(bool) - config.Server.PublicPrefix = data.GetOr("server.public_prefix", "public").(string) - - config.Runner.PoolSize = data.GetOr("runner.pool_size", runtime.GOMAXPROCS(0)).(int) - - config.Dirs.Routes = data.GetOr("dirs.routes", "routes").(string) - config.Dirs.Public = data.GetOr("dirs.public", "public").(string) - config.Dirs.FS = data.GetOr("dirs.fs", "fs").(string) - config.Dirs.Data = data.GetOr("dirs.data", "data").(string) - config.Dirs.Override = data.GetOr("dirs.override", "override").(string) - config.Dirs.Libs = data.GetOr("dirs.libs", []string{"libs"}).([]string) - - return config -} diff --git a/go.mod b/go.mod index 157815b..3a67a5a 100644 --- a/go.mod +++ b/go.mod @@ -3,35 +3,20 @@ module Moonshark go 1.24.1 require ( - git.sharkk.net/Go/Color v1.1.0 - git.sharkk.net/Go/LRU v1.0.0 - git.sharkk.net/Sharkk/Fin v1.3.0 - git.sharkk.net/Sky/LuaJIT-to-Go v0.5.3 - github.com/VictoriaMetrics/fastcache v1.12.4 - github.com/alexedwards/argon2id v1.0.0 - github.com/deneonet/benc v1.1.8 + git.sharkk.net/Sky/LuaJIT-to-Go v0.5.4 github.com/goccy/go-json v0.10.5 - github.com/golang/snappy v1.0.0 - github.com/matoous/go-nanoid/v2 v2.1.0 - github.com/valyala/bytebufferpool v1.0.0 - github.com/valyala/fasthttp v1.62.0 - zombiezen.com/go/sqlite v1.4.2 + github.com/valyala/fasthttp v1.63.0 ) require ( - github.com/andybalholm/brotli v1.1.1 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/dustin/go-humanize v1.0.1 // indirect - github.com/google/uuid v1.6.0 // indirect + github.com/VictoriaMetrics/fastcache v1.12.5 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/deneonet/benc v1.1.8 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/ncruces/go-strftime v0.1.9 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - golang.org/x/crypto v0.38.0 // indirect - golang.org/x/exp v0.0.0-20250531010427-b6e5de432a8b // indirect + github.com/matoous/go-nanoid/v2 v2.1.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect golang.org/x/sys v0.33.0 // indirect - modernc.org/libc v1.65.8 // indirect - modernc.org/mathutil v1.7.1 // indirect - modernc.org/memory v1.11.0 // indirect - modernc.org/sqlite v1.37.1 // indirect ) diff --git a/go.sum b/go.sum index ae66c7f..2d3bd89 100644 --- a/go.sum +++ b/go.sum @@ -1,133 +1,32 @@ -git.sharkk.net/Go/Color v1.1.0 h1:1eyUwlcerJLo9/42GSnQxOY84/Htdwz/QTu3FGgLEXk= -git.sharkk.net/Go/Color v1.1.0/go.mod h1:cyZFLbUh+GkpsIABVxb5w9EZM+FPj+q9GkCsoECaeTI= -git.sharkk.net/Go/LRU v1.0.0 h1:/KqdRVhHldi23aVfQZ4ss6vhCWZqA3vFiQyf1MJPpQc= -git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50= -git.sharkk.net/Sharkk/Fin v1.3.0 h1:6/f7+h382jJOeo09cgdzH+PGb5XdvajZZRiES52sBkI= -git.sharkk.net/Sharkk/Fin v1.3.0/go.mod h1:ca0Ej9yCM/vHh1o3YMvBZspme3EtbwoEL2UXN5UPXMo= -git.sharkk.net/Sky/LuaJIT-to-Go v0.5.3 h1:SuLz4X/k+sMy+Uj1lhEy6brJtvtzHLdivUcu5K91y+o= -git.sharkk.net/Sky/LuaJIT-to-Go v0.5.3/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8= -github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0= -github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI= -github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w= -github.com/alexedwards/argon2id v1.0.0/go.mod h1:tYKkqIjzXvZdzPvADMWOEZ+l6+BD6CtBXMj5fnJppiw= -github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8= -github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= +git.sharkk.net/Sky/LuaJIT-to-Go v0.5.4 h1:j83C2pzDBaVP4FPpyBzP4Dch61plzIko/t7zFvjOK2I= +git.sharkk.net/Sky/LuaJIT-to-Go v0.5.4/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8= +github.com/VictoriaMetrics/fastcache v1.12.5 h1:966OX9JjqYmDAFdp3wEXLwzukiHIm+GVlZHv6B8KW3k= +github.com/VictoriaMetrics/fastcache v1.12.5/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/deneonet/benc v1.1.8 h1:Qk9diyH0UcnduvCrZ62mBrwUeSZzte4kQxMbclVdhW4= github.com/deneonet/benc v1.1.8/go.mod h1:UCfkM5Od0B2huwv/ZItvtUb7QnALFt9YXtX8NXX4Lts= -github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= -github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE= github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= -github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0= github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg= +github.com/valyala/fasthttp v1.63.0 h1:DisIL8OjB7ul2d7cBaMRcKTQDYnrGy56R4FCiuDP0Ns= +github.com/valyala/fasthttp v1.63.0/go.mod h1:REc4IeW+cAEyLrRPa5A81MIjvz0QE1laoTX2EaPHKJM= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= -golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= -golang.org/x/exp v0.0.0-20250531010427-b6e5de432a8b h1:QoALfVG9rhQ/M7vYDScfPdWjGL9dlsVVM5VGh7aKoAA= -golang.org/x/exp v0.0.0-20250531010427-b6e5de432a8b/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= -golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= -golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= -golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -modernc.org/cc/v4 v4.26.1 h1:+X5NtzVBn0KgsBCBe+xkDC7twLb/jNVj9FPgiwSQO3s= -modernc.org/cc/v4 v4.26.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= -modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU= -modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE= -modernc.org/fileutil v1.3.1 h1:8vq5fe7jdtEvoCf3Zf9Nm0Q05sH6kGx0Op2CPx1wTC8= -modernc.org/fileutil v1.3.1/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= -modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= -modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= -modernc.org/libc v1.65.8 h1:7PXRJai0TXZ8uNA3srsmYzmTyrLoHImV5QxHeni108Q= -modernc.org/libc v1.65.8/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU= -modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= -modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= -modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= -modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= -modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= -modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= -modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= -modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= -modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs= -modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g= -modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= -modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= -modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= -modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= -zombiezen.com/go/sqlite v1.4.2 h1:KZXLrBuJ7tKNEm+VJcApLMeQbhmAUOKA5VWS93DfFRo= -zombiezen.com/go/sqlite v1.4.2/go.mod h1:5Kd4taTAD4MkBzT25mQ9uaAlLjyR0rFhsR6iINO70jc= diff --git a/http/http.go b/http/http.go new file mode 100644 index 0000000..b05b88c --- /dev/null +++ b/http/http.go @@ -0,0 +1,756 @@ +package http + +import ( + "crypto/rand" + _ "embed" + "encoding/base64" + "fmt" + "mime/multipart" + "strings" + "sync" + "time" + + "Moonshark/http/router" + "Moonshark/http/sessions" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "github.com/goccy/go-json" + "github.com/valyala/fasthttp" +) + +//go:embed http.lua +var httpLuaCode string + +type Server struct { + server *fasthttp.Server + router *router.Router + sessions *sessions.SessionManager + state *luajit.State + stateMu sync.Mutex + funcCounter int +} + +type RequestContext struct { + Method string + Path string + Headers map[string]string + Query map[string]string + Form map[string]any + Cookies map[string]string + Session *sessions.Session + Body string + Params map[string]string +} + +var globalServer *Server + +func NewServer(state *luajit.State) *Server { + return &Server{ + router: router.New(), + sessions: sessions.NewSessionManager(10000), + state: state, + } +} + +func RegisterHTTPFunctions(L *luajit.State) error { + globalServer = NewServer(L) + + functions := map[string]luajit.GoFunction{ + "__http_listen": globalServer.httpListen, + "__http_route": globalServer.httpRoute, + "__http_set_status": httpSetStatus, + "__http_set_header": httpSetHeader, + "__http_redirect": httpRedirect, + "__session_get": globalServer.sessionGet, + "__session_set": globalServer.sessionSet, + "__session_flash": globalServer.sessionFlash, + "__session_get_flash": globalServer.sessionGetFlash, + "__cookie_set": cookieSet, + "__cookie_get": cookieGet, + "__csrf_generate": globalServer.csrfGenerate, + "__csrf_validate": globalServer.csrfValidate, + } + + for name, fn := range functions { + if err := L.RegisterGoFunction(name, fn); err != nil { + return err + } + } + + return L.DoString(httpLuaCode) +} + +func (s *Server) httpListen(state *luajit.State) int { + port, err := state.SafeToNumber(1) + if err != nil { + return state.PushError("listen: port must be number") + } + + s.server = &fasthttp.Server{ + Handler: s.requestHandler, + Name: "Moonshark/1.0", + } + + addr := fmt.Sprintf(":%d", int(port)) + go func() { + if err := s.server.ListenAndServe(addr); err != nil { + fmt.Printf("Server error: %v\n", err) + } + }() + + fmt.Printf("Server listening on port %d\n", int(port)) + state.PushBoolean(true) + return 1 +} + +func (s *Server) httpRoute(state *luajit.State) int { + method, err := state.SafeToString(1) + if err != nil { + return state.PushError("route: method must be string") + } + + path, err := state.SafeToString(2) + if err != nil { + return state.PushError("route: path must be string") + } + + if !state.IsFunction(3) { + return state.PushError("route: handler must be function") + } + + // Store function and get reference + state.PushCopy(3) + funcRef := s.storeFunction() + + // Add route to router + if err := s.router.AddRoute(strings.ToUpper(method), path, funcRef); err != nil { + return state.PushError("route: failed to add route: %s", err.Error()) + } + + state.PushBoolean(true) + return 1 +} + +func (s *Server) storeFunction() int { + s.state.GetGlobal("__moonshark_functions") + if s.state.IsNil(-1) { + s.state.Pop(1) + s.state.NewTable() + s.state.PushCopy(-1) + s.state.SetGlobal("__moonshark_functions") + } + + s.funcCounter++ + s.state.PushNumber(float64(s.funcCounter)) + s.state.PushCopy(-3) + s.state.SetTable(-3) + s.state.Pop(2) + + return s.funcCounter +} + +func (s *Server) getFunction(ref int) bool { + s.state.GetGlobal("__moonshark_functions") + if s.state.IsNil(-1) { + s.state.Pop(1) + return false + } + + s.state.PushNumber(float64(ref)) + s.state.GetTable(-2) + isFunc := s.state.IsFunction(-1) + if !isFunc { + s.state.Pop(2) + return false + } + + s.state.Remove(-2) + return true +} + +func (s *Server) requestHandler(ctx *fasthttp.RequestCtx) { + method := string(ctx.Method()) + path := string(ctx.Path()) + + // Look up route in router + handlerRef, params, found := s.router.Lookup(method, path) + if !found { + ctx.SetStatusCode(404) + ctx.SetBodyString("Not Found") + return + } + + reqCtx := s.buildRequestContext(ctx, params) + reqCtx.Session.AdvanceFlash() + + s.stateMu.Lock() + defer s.stateMu.Unlock() + + s.setupRequestEnvironment(reqCtx) + + if !s.getFunction(handlerRef) { + ctx.SetStatusCode(500) + ctx.SetBodyString("Handler not found") + return + } + + if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil { + ctx.SetStatusCode(500) + ctx.SetBodyString("Failed to create request object") + return + } + + if err := s.state.Call(1, 1); err != nil { + ctx.SetStatusCode(500) + ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err)) + return + } + + var responseBody string + if s.state.GetTop() > 0 && !s.state.IsNil(-1) { + responseBody = s.state.ToString(-1) + s.state.Pop(1) + } + + s.updateSessionFromLua(reqCtx.Session) + s.applyResponse(ctx, responseBody) + s.sessions.ApplySessionCookie(ctx, reqCtx.Session) +} + +func (s *Server) setupRequestEnvironment(reqCtx *RequestContext) { + s.state.PushValue(s.requestToTable(reqCtx)) + s.state.SetGlobal("__request") + + s.state.PushValue(s.sessionToTable(reqCtx.Session)) + s.state.SetGlobal("__session") + + s.state.NewTable() + s.state.SetGlobal("__response") +} + +func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any { + return map[string]any{ + "method": reqCtx.Method, + "path": reqCtx.Path, + "headers": reqCtx.Headers, + "query": reqCtx.Query, + "form": reqCtx.Form, + "cookies": reqCtx.Cookies, + "body": reqCtx.Body, + "params": reqCtx.Params, + } +} + +func (s *Server) sessionToTable(session *sessions.Session) map[string]any { + return map[string]any{ + "id": session.ID, + "data": session.GetAll(), + } +} + +func (s *Server) updateSessionFromLua(session *sessions.Session) { + s.state.GetGlobal("__session") + if s.state.IsNil(-1) { + s.state.Pop(1) + return + } + + s.state.GetField(-1, "data") + if s.state.IsTable(-1) { + if data, err := s.state.ToTable(-1); err == nil { + if dataMap, ok := data.(map[string]any); ok { + session.Clear() + for k, v := range dataMap { + session.Set(k, v) + } + } + } + } + s.state.Pop(2) +} + +func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, body string) { + s.state.GetGlobal("__response") + if s.state.IsNil(-1) { + s.state.Pop(1) + if body != "" { + ctx.SetBodyString(body) + } + return + } + + s.state.GetField(-1, "status") + if s.state.IsNumber(-1) { + ctx.SetStatusCode(int(s.state.ToNumber(-1))) + } + s.state.Pop(1) + + s.state.GetField(-1, "headers") + if s.state.IsTable(-1) { + s.state.ForEachTableKV(-1, func(key, value string) bool { + ctx.Response.Header.Set(key, value) + return true + }) + } + s.state.Pop(1) + + s.state.GetField(-1, "cookies") + if s.state.IsTable(-1) { + s.applyCookies(ctx) + } + s.state.Pop(1) + + s.state.Pop(1) + + if body != "" { + ctx.SetBodyString(body) + } +} + +func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) { + s.state.ForEachArray(-1, func(i int, state *luajit.State) bool { + if !state.IsTable(-1) { + return true + } + + name := state.GetFieldString(-1, "name", "") + value := state.GetFieldString(-1, "value", "") + if name == "" { + return true + } + + cookie := fasthttp.AcquireCookie() + defer fasthttp.ReleaseCookie(cookie) + + cookie.SetKey(name) + cookie.SetValue(value) + cookie.SetPath(state.GetFieldString(-1, "path", "/")) + + if domain := state.GetFieldString(-1, "domain", ""); domain != "" { + cookie.SetDomain(domain) + } + + if state.GetFieldBool(-1, "secure", false) { + cookie.SetSecure(true) + } + + if state.GetFieldBool(-1, "http_only", true) { + cookie.SetHTTPOnly(true) + } + + if maxAge := state.GetFieldNumber(-1, "max_age", 0); maxAge > 0 { + cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second)) + } + + ctx.Response.Header.SetCookie(cookie) + return true + }) +} + +func (s *Server) buildRequestContext(ctx *fasthttp.RequestCtx, params *router.Params) *RequestContext { + reqCtx := &RequestContext{ + Method: string(ctx.Method()), + Path: string(ctx.Path()), + Headers: make(map[string]string), + Query: make(map[string]string), + Cookies: make(map[string]string), + Body: string(ctx.PostBody()), + Params: make(map[string]string), + } + + ctx.Request.Header.VisitAll(func(key, value []byte) { + reqCtx.Headers[string(key)] = string(value) + }) + + ctx.Request.Header.VisitAllCookie(func(key, value []byte) { + reqCtx.Cookies[string(key)] = string(value) + }) + + ctx.QueryArgs().VisitAll(func(key, value []byte) { + reqCtx.Query[string(key)] = string(value) + }) + + // Convert router params to map + for i, key := range params.Keys { + if i < len(params.Values) { + reqCtx.Params[key] = params.Values[i] + } + } + + reqCtx.Form = s.parseForm(ctx) + reqCtx.Session = s.sessions.GetSessionFromRequest(ctx) + + return reqCtx +} + +func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any { + contentType := string(ctx.Request.Header.ContentType()) + form := make(map[string]any) + + if strings.Contains(contentType, "application/json") { + var data map[string]any + if err := json.Unmarshal(ctx.PostBody(), &data); err == nil { + return data + } + } else if strings.Contains(contentType, "application/x-www-form-urlencoded") { + ctx.PostArgs().VisitAll(func(key, value []byte) { + form[string(key)] = string(value) + }) + } else if strings.Contains(contentType, "multipart/form-data") { + if multipartForm, err := ctx.MultipartForm(); err == nil { + for key, values := range multipartForm.Value { + if len(values) == 1 { + form[key] = values[0] + } else { + form[key] = values + } + } + + if len(multipartForm.File) > 0 { + files := make(map[string]any) + for fieldName, fileHeaders := range multipartForm.File { + if len(fileHeaders) == 1 { + files[fieldName] = s.fileToMap(fileHeaders[0]) + } else { + fileList := make([]map[string]any, len(fileHeaders)) + for i, fh := range fileHeaders { + fileList[i] = s.fileToMap(fh) + } + files[fieldName] = fileList + } + } + form["_files"] = files + } + } + } + + return form +} + +func (s *Server) fileToMap(fh *multipart.FileHeader) map[string]any { + return map[string]any{ + "filename": fh.Filename, + "size": fh.Size, + "mimetype": fh.Header.Get("Content-Type"), + } +} + +func (s *Server) generateCSRFToken() string { + bytes := make([]byte, 32) + rand.Read(bytes) + return base64.URLEncoding.EncodeToString(bytes) +} + +// Lua function implementations +func httpSetStatus(state *luajit.State) int { + code, _ := state.SafeToNumber(1) + state.GetGlobal("__response") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.SetGlobal("__response") + state.GetGlobal("__response") + } + state.PushNumber(code) + state.SetField(-2, "status") + state.Pop(1) + return 0 +} + +func httpSetHeader(state *luajit.State) int { + name, _ := state.SafeToString(1) + value, _ := state.SafeToString(2) + + state.GetGlobal("__response") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.SetGlobal("__response") + state.GetGlobal("__response") + } + + state.GetField(-1, "headers") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.PushCopy(-1) + state.SetField(-3, "headers") + } + + state.PushString(value) + state.SetField(-2, name) + state.Pop(2) + return 0 +} + +func httpRedirect(state *luajit.State) int { + url, _ := state.SafeToString(1) + status := 302.0 + if state.GetTop() >= 2 { + status, _ = state.SafeToNumber(2) + } + + state.GetGlobal("__response") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.SetGlobal("__response") + state.GetGlobal("__response") + } + + state.PushNumber(status) + state.SetField(-2, "status") + + state.GetField(-1, "headers") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.PushCopy(-1) + state.SetField(-3, "headers") + } + + state.PushString(url) + state.SetField(-2, "Location") + state.Pop(2) + return 0 +} + +func (s *Server) sessionGet(state *luajit.State) int { + key, _ := state.SafeToString(1) + + state.GetGlobal("__session") + if state.IsNil(-1) { + state.Pop(1) + state.PushNil() + return 1 + } + + state.GetField(-1, "data") + if state.IsNil(-1) { + state.Pop(2) + state.PushNil() + return 1 + } + + state.GetField(-1, key) + return 1 +} + +func (s *Server) sessionSet(state *luajit.State) int { + key, _ := state.SafeToString(1) + + state.GetGlobal("__session") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.SetGlobal("__session") + state.GetGlobal("__session") + } + + state.GetField(-1, "data") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.PushCopy(-1) + state.SetField(-3, "data") + } + + value, err := state.ToValue(2) + if err == nil { + state.PushValue(value) + state.SetField(-2, key) + } + state.Pop(2) + return 0 +} + +func (s *Server) sessionFlash(state *luajit.State) int { + key, _ := state.SafeToString(1) + + state.GetGlobal("__session") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.SetGlobal("__session") + state.GetGlobal("__session") + } + + state.GetField(-1, "flash") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.PushCopy(-1) + state.SetField(-3, "flash") + } + + value, err := state.ToValue(2) + if err == nil { + state.PushValue(value) + state.SetField(-2, key) + } + state.Pop(2) + return 0 +} + +func (s *Server) sessionGetFlash(state *luajit.State) int { + key, _ := state.SafeToString(1) + + state.GetGlobal("__session") + if state.IsNil(-1) { + state.Pop(1) + state.PushNil() + return 1 + } + + state.GetField(-1, "flash") + if state.IsNil(-1) { + state.Pop(2) + state.PushNil() + return 1 + } + + state.GetField(-1, key) + return 1 +} + +func cookieSet(state *luajit.State) int { + name, _ := state.SafeToString(1) + value, _ := state.SafeToString(2) + + maxAge := 0 + path := "/" + domain := "" + secure := false + httpOnly := true + + if state.GetTop() >= 3 && state.IsTable(3) { + maxAge = int(state.GetFieldNumber(3, "max_age", 0)) + path = state.GetFieldString(3, "path", "/") + domain = state.GetFieldString(3, "domain", "") + secure = state.GetFieldBool(3, "secure", false) + httpOnly = state.GetFieldBool(3, "http_only", true) + } + + state.GetGlobal("__response") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.SetGlobal("__response") + state.GetGlobal("__response") + } + + state.GetField(-1, "cookies") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.PushCopy(-1) + state.SetField(-3, "cookies") + } + + cookieData := map[string]any{ + "name": name, + "value": value, + "path": path, + "secure": secure, + "http_only": httpOnly, + } + if domain != "" { + cookieData["domain"] = domain + } + if maxAge > 0 { + cookieData["max_age"] = maxAge + } + + state.PushValue(cookieData) + + length := globalServer.getTableLength(-2) + state.PushNumber(float64(length + 1)) + state.PushCopy(-2) + state.SetTable(-4) + + state.Pop(3) + return 0 +} + +func cookieGet(state *luajit.State) int { + name, _ := state.SafeToString(1) + + state.GetGlobal("__request") + if state.IsNil(-1) { + state.Pop(1) + state.PushNil() + return 1 + } + + state.GetField(-1, "cookies") + if state.IsNil(-1) { + state.Pop(2) + state.PushNil() + return 1 + } + + state.GetField(-1, name) + return 1 +} + +func (s *Server) csrfGenerate(state *luajit.State) int { + token := s.generateCSRFToken() + + state.GetGlobal("__session") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.SetGlobal("__session") + state.GetGlobal("__session") + } + + state.GetField(-1, "data") + if state.IsNil(-1) { + state.Pop(1) + state.NewTable() + state.PushCopy(-1) + state.SetField(-3, "data") + } + + state.PushString(token) + state.SetField(-2, "_csrf_token") + state.Pop(2) + + state.PushString(token) + return 1 +} + +func (s *Server) csrfValidate(state *luajit.State) int { + state.GetGlobal("__session") + if state.IsNil(-1) { + state.Pop(1) + state.PushBoolean(false) + return 1 + } + + sessionToken := state.GetFieldString(-1, "data._csrf_token", "") + state.Pop(1) + + state.GetGlobal("__request") + if state.IsNil(-1) { + state.Pop(1) + state.PushBoolean(false) + return 1 + } + + requestToken := state.GetFieldString(-1, "form._csrf_token", "") + state.Pop(1) + + state.PushBoolean(sessionToken != "" && sessionToken == requestToken) + return 1 +} + +func (s *Server) getTableLength(index int) int { + length := 0 + s.state.PushNil() + for s.state.Next(index - 1) { + length++ + s.state.Pop(1) + } + return length +} diff --git a/http/http.lua b/http/http.lua new file mode 100644 index 0000000..d28cbe5 --- /dev/null +++ b/http/http.lua @@ -0,0 +1,115 @@ +http = {} + +function http.listen(port) + return __http_listen(port) +end + +function http.route(method, path, handler) + return __http_route(method, path, handler) +end + +function http.status(code) + return __http_set_status(code) +end + +function http.header(name, value) + return __http_set_header(name, value) +end + +function http.redirect(url, status) + __http_redirect(url, status or 302) + coroutine.yield() -- Exit handler +end + +function http.json(data) + http.header("Content-Type", "application/json") + return json.encode(data) +end + +function http.html(content) + http.header("Content-Type", "text/html") + return content +end + +function http.text(content) + http.header("Content-Type", "text/plain") + return content +end + +-- Session functions +session = {} + +function session.get(key) + return __session_get(key) +end + +function session.set(key, value) + return __session_set(key, value) +end + +function session.flash(key, value) + return __session_flash(key, value) +end + +function session.get_flash(key) + return __session_get_flash(key) +end + +-- Cookie functions +cookie = {} + +function cookie.set(name, value, options) + return __cookie_set(name, value, options) +end + +function cookie.get(name) + return __cookie_get(name) +end + +-- CSRF functions +csrf = {} + +function csrf.generate() + return __csrf_generate() +end + +function csrf.validate() + return __csrf_validate() +end + +function csrf.field() + local token = csrf.generate() + return string.format('', token) +end + +-- Helper functions +function redirect_with_flash(url, type, message) + session.flash(type, message) + http.redirect(url) +end + +-- JSON encoding/decoding placeholder +json = { + encode = function(data) + -- Simplified JSON encoding + if type(data) == "table" then + local result = "{" + local first = true + for k, v in pairs(data) do + if not first then result = result .. "," end + result = result .. '"' .. tostring(k) .. '":' .. json.encode(v) + first = false + end + return result .. "}" + elseif type(data) == "string" then + return '"' .. data .. '"' + else + return tostring(data) + end + end, + + decode = function(str) + -- Simplified JSON decoding - you'd want a proper implementation + return {} + end +} \ No newline at end of file diff --git a/http/router/router.go b/http/router/router.go new file mode 100644 index 0000000..c2f2354 --- /dev/null +++ b/http/router/router.go @@ -0,0 +1,260 @@ +package router + +import ( + "errors" + "strings" + "sync" +) + +// node represents a node in the radix trie +type node struct { + segment string + handler int // Lua function reference + children []*node + isDynamic bool // :param + isWildcard bool // *param + paramName string +} + +// Router is a string-based HTTP router with efficient lookup +type Router struct { + get, post, put, patch, delete *node + mu sync.RWMutex +} + +// Params holds URL parameters +type Params struct { + Keys []string + Values []string +} + +// Get returns a parameter value by name +func (p *Params) Get(name string) string { + for i, key := range p.Keys { + if key == name && i < len(p.Values) { + return p.Values[i] + } + } + return "" +} + +// New creates a new Router instance +func New() *Router { + return &Router{ + get: &node{}, + post: &node{}, + put: &node{}, + patch: &node{}, + delete: &node{}, + } +} + +// methodNode returns the root node for a method +func (r *Router) methodNode(method string) *node { + switch method { + case "GET": + return r.get + case "POST": + return r.post + case "PUT": + return r.put + case "PATCH": + return r.patch + case "DELETE": + return r.delete + default: + return nil + } +} + +// AddRoute adds a new route with handler reference +func (r *Router) AddRoute(method, path string, handlerRef int) error { + r.mu.Lock() + defer r.mu.Unlock() + + root := r.methodNode(method) + if root == nil { + return errors.New("unsupported HTTP method") + } + + if path == "/" { + root.handler = handlerRef + return nil + } + + return r.addRoute(root, path, handlerRef) +} + +// addRoute adds a route to the trie +func (r *Router) addRoute(root *node, path string, handlerRef int) error { + segments := r.parseSegments(path) + current := root + + for _, seg := range segments { + isDyn := strings.HasPrefix(seg, ":") + isWC := strings.HasPrefix(seg, "*") + + if isWC && seg != segments[len(segments)-1] { + return errors.New("wildcard must be the last segment") + } + + paramName := "" + if isDyn { + paramName = seg[1:] + seg = ":" + } else if isWC { + paramName = seg[1:] + seg = "*" + } + + // Find or create child + var child *node + for _, c := range current.children { + if c.segment == seg { + child = c + break + } + } + + if child == nil { + child = &node{ + segment: seg, + isDynamic: isDyn, + isWildcard: isWC, + paramName: paramName, + } + current.children = append(current.children, child) + } + + current = child + } + + current.handler = handlerRef + return nil +} + +// parseSegments splits path into segments +func (r *Router) parseSegments(path string) []string { + segments := strings.Split(strings.Trim(path, "/"), "/") + var result []string + for _, seg := range segments { + if seg != "" { + result = append(result, seg) + } + } + return result +} + +// Lookup finds handler and parameters for a method and path +func (r *Router) Lookup(method, path string) (int, *Params, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + root := r.methodNode(method) + if root == nil { + return 0, nil, false + } + + if path == "/" { + if root.handler != 0 { + return root.handler, &Params{}, true + } + return 0, nil, false + } + + segments := r.parseSegments(path) + handler, params := r.match(root, segments, 0) + if handler == 0 { + return 0, nil, false + } + + return handler, params, true +} + +// match traverses the trie to find handler +func (r *Router) match(current *node, segments []string, index int) (int, *Params) { + if index >= len(segments) { + if current.handler != 0 { + return current.handler, &Params{} + } + return 0, nil + } + + segment := segments[index] + + // Check exact match first + for _, child := range current.children { + if child.segment == segment { + handler, params := r.match(child, segments, index+1) + if handler != 0 { + return handler, params + } + } + } + + // Check dynamic match second + for _, child := range current.children { + if child.isDynamic { + handler, params := r.match(child, segments, index+1) + if handler != 0 { + // Prepend this parameter + newParams := &Params{ + Keys: append([]string{child.paramName}, params.Keys...), + Values: append([]string{segment}, params.Values...), + } + return handler, newParams + } + } + } + + // Check wildcard last (catches everything remaining) + for _, child := range current.children { + if child.isWildcard { + remaining := strings.Join(segments[index:], "/") + return child.handler, &Params{ + Keys: []string{child.paramName}, + Values: []string{remaining}, + } + } + } + + return 0, nil +} + +// RemoveRoute removes a route +func (r *Router) RemoveRoute(method, path string) { + r.mu.Lock() + defer r.mu.Unlock() + + root := r.methodNode(method) + if root == nil { + return + } + + if path == "/" { + root.handler = 0 + return + } + + segments := r.parseSegments(path) + r.removeRoute(root, segments, 0) +} + +// removeRoute removes a route from the trie +func (r *Router) removeRoute(current *node, segments []string, index int) { + if index >= len(segments) { + current.handler = 0 + return + } + + segment := segments[index] + + for _, child := range current.children { + if child.segment == segment || + (child.isDynamic && strings.HasPrefix(segment, ":")) || + (child.isWildcard && strings.HasPrefix(segment, "*")) { + r.removeRoute(child, segments, index+1) + break + } + } +} diff --git a/http/server.go b/http/server.go deleted file mode 100644 index a423fbe..0000000 --- a/http/server.go +++ /dev/null @@ -1,167 +0,0 @@ -// server.go - Simplified HTTP server -package http - -import ( - "fmt" - "strings" - "time" - - "Moonshark/config" - "Moonshark/metadata" - "Moonshark/runner" - "Moonshark/utils" - - "github.com/goccy/go-json" - "github.com/valyala/bytebufferpool" - "github.com/valyala/fasthttp" -) - -func NewHttpServer(cfg *config.Config, handler fasthttp.RequestHandler, dbg bool) *fasthttp.Server { - return &fasthttp.Server{ - Handler: handler, - Name: "Moonshark/" + metadata.Version, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 120 * time.Second, - MaxRequestBodySize: 16 << 20, - TCPKeepalive: true, - ReduceMemoryUsage: true, - StreamRequestBody: true, - NoDefaultServerHeader: true, - } -} - -func NewPublicHandler(pubDir, prefix string) fasthttp.RequestHandler { - if !strings.HasPrefix(prefix, "/") { - prefix = "/" + prefix - } - if !strings.HasSuffix(prefix, "/") { - prefix += "/" - } - - fs := &fasthttp.FS{ - Root: pubDir, - IndexNames: []string{"index.html"}, - AcceptByteRange: true, - Compress: true, - CompressedFileSuffix: ".gz", - CompressBrotli: true, - PathRewrite: fasthttp.NewPathPrefixStripper(len(prefix) - 1), - } - return fs.NewRequestHandler() -} - -func Send404(ctx *fasthttp.RequestCtx) { - ctx.SetContentType("text/html; charset=utf-8") - ctx.SetStatusCode(fasthttp.StatusNotFound) - ctx.SetBody([]byte(utils.NotFoundPage(ctx.URI().String()))) -} - -func Send500(ctx *fasthttp.RequestCtx, err error) { - ctx.SetContentType("text/html; charset=utf-8") - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - - if err == nil { - ctx.SetBody([]byte(utils.InternalErrorPage(string(ctx.Path()), ""))) - } else { - ctx.SetBody([]byte(utils.InternalErrorPage(string(ctx.Path()), err.Error()))) - } -} - -// ApplyResponse applies a Response to a fasthttp.RequestCtx -func ApplyResponse(resp *runner.Response, ctx *fasthttp.RequestCtx) { - // Set status code - ctx.SetStatusCode(resp.Status) - - // Set headers - for name, value := range resp.Headers { - ctx.Response.Header.Set(name, value) - } - - // Set cookies - for _, cookie := range resp.Cookies { - ctx.Response.Header.SetCookie(cookie) - } - - // Process the body based on its type - if resp.Body == nil { - return - } - - // Check if Content-Type was manually set - contentTypeSet := false - for name := range resp.Headers { - if strings.ToLower(name) == "content-type" { - contentTypeSet = true - break - } - } - - // Get a buffer from the pool - buf := bytebufferpool.Get() - defer bytebufferpool.Put(buf) - - // Set body based on type - switch body := resp.Body.(type) { - case string: - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBodyString(body) - case []byte: - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBody(body) - case map[string]any, map[any]any, []any, []float64, []string, []int, []map[string]any: - // Marshal JSON - if err := json.NewEncoder(buf).Encode(body); err == nil { - if !contentTypeSet { - ctx.Response.Header.SetContentType("application/json") - } - ctx.SetBody(buf.Bytes()) - } else { - // Fallback to string representation - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } - default: - // Check if it's any other map or slice type - typeStr := fmt.Sprintf("%T", body) - if typeStr[0] == '[' || (len(typeStr) > 3 && typeStr[:3] == "map") { - if err := json.NewEncoder(buf).Encode(body); err == nil { - if !contentTypeSet { - ctx.Response.Header.SetContentType("application/json") - } - ctx.SetBody(buf.Bytes()) - } else { - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } - } else { - // Default to string representation - if !contentTypeSet { - ctx.Response.Header.SetContentType("text/plain; charset=utf-8") - } - ctx.SetBodyString(fmt.Sprintf("%v", body)) - } - } -} - -/* -func HandleDebugStats(ctx *fasthttp.RequestCtx) { - stats := utils.CollectSystemStats(s.cfg) - stats.Components = utils.ComponentStats{ - RouteCount: 0, // TODO: Get from router - BytecodeBytes: 0, // TODO: Get from router - SessionStats: s.sessionManager.GetCacheStats(), - } - ctx.SetContentType("text/html; charset=utf-8") - ctx.SetStatusCode(fasthttp.StatusOK) - ctx.SetBody([]byte(utils.DebugStatsPage(stats))) -} -*/ diff --git a/sessions/manager.go b/http/sessions/manager.go similarity index 100% rename from sessions/manager.go rename to http/sessions/manager.go diff --git a/sessions/session.go b/http/sessions/session.go similarity index 100% rename from sessions/session.go rename to http/sessions/session.go diff --git a/logger/logger.go b/logger/logger.go deleted file mode 100644 index 2cc5f38..0000000 --- a/logger/logger.go +++ /dev/null @@ -1,208 +0,0 @@ -package logger - -import ( - "fmt" - "io" - "os" - "strings" - "sync" - "sync/atomic" - "time" - - "git.sharkk.net/Go/Color" -) - -// Log levels -const ( - LevelDebug = iota - LevelInfo - LevelWarn - LevelError - LevelFatal -) - -// Level config -var levels = map[int]struct { - tag string - color func(string) string -}{ - LevelDebug: {"D", color.Cyan}, - LevelInfo: {"I", color.Blue}, - LevelWarn: {"W", color.Yellow}, - LevelError: {"E", color.Red}, - LevelFatal: {"F", color.Purple}, -} - -var ( - global *Logger - globalOnce sync.Once -) - -type Logger struct { - out io.Writer - enabled atomic.Bool - timestamp atomic.Bool - debug atomic.Bool - http atomic.Bool - colors atomic.Bool - mu sync.Mutex -} - -func init() { - globalOnce.Do(func() { - global = &Logger{out: os.Stdout} - global.enabled.Store(true) - global.timestamp.Store(true) - global.colors.Store(true) - }) -} - -func applyColor(text string, colorFunc func(string) string) string { - if global.colors.Load() { - return colorFunc(text) - } - return text -} - -func write(level int, msg string) { - if !global.enabled.Load() { - return - } - - if level == LevelDebug && !global.debug.Load() { - return - } - - var parts []string - - if global.timestamp.Load() { - ts := applyColor(time.Now().Format("3:04PM"), color.Gray) - parts = append(parts, ts) - } - - if cfg, ok := levels[level]; ok { - tag := applyColor("["+cfg.tag+"]", cfg.color) - parts = append(parts, tag) - } - - parts = append(parts, msg) - line := strings.Join(parts, " ") + "\n" - - global.mu.Lock() - fmt.Fprint(global.out, line) - if level == LevelFatal { - if f, ok := global.out.(*os.File); ok { - f.Sync() - } - } - global.mu.Unlock() - - if level == LevelFatal { - os.Exit(1) - } -} - -func log(level int, format string, args ...any) { - var msg string - if len(args) > 0 { - msg = fmt.Sprintf(format, args...) - } else { - msg = format - } - write(level, msg) -} - -func Debugf(format string, args ...any) { log(LevelDebug, format, args...) } -func Infof(format string, args ...any) { log(LevelInfo, format, args...) } -func Warnf(format string, args ...any) { log(LevelWarn, format, args...) } -func Errorf(format string, args ...any) { log(LevelError, format, args...) } -func Fatalf(format string, args ...any) { log(LevelFatal, format, args...) } - -func Raw(format string, args ...any) { - if !global.enabled.Load() { - return - } - - var msg string - if len(args) > 0 { - msg = fmt.Sprintf(format, args...) - } else { - msg = format - } - - global.mu.Lock() - fmt.Fprint(global.out, msg+"\n") - global.mu.Unlock() -} - -// Attempts to write the HTTP request result to the log. Will not print if -// http is disabled on the logger. Method and path take byte slices for convenience. -func Request(status int, method, path []byte, duration time.Duration) { - if !global.enabled.Load() { - return - } - - if !global.http.Load() { - return - } - - var statusColor func(string) string - switch { - case status < 300: - statusColor = color.Green - case status < 400: - statusColor = color.Cyan - case status < 500: - statusColor = color.Yellow - default: - statusColor = color.Red - } - - var dur string - us := duration.Microseconds() - switch { - case us < 1000: - dur = fmt.Sprintf("%.0fµs", float64(us)) - case us < 1000000: - dur = fmt.Sprintf("%.1fms", float64(us)/1000) - default: - dur = fmt.Sprintf("%.2fs", duration.Seconds()) - } - - var parts []string - - if global.timestamp.Load() { - ts := applyColor(time.Now().Format("3:04PM"), color.Gray) - parts = append(parts, ts) - } - - parts = append(parts, - applyColor("["+string(method)+"]", color.Gray), - applyColor(fmt.Sprintf("%d", status), statusColor), - applyColor(string(path), color.Gray), - applyColor(dur, color.Gray), - ) - - msg := strings.Join(parts, " ") - - global.mu.Lock() - fmt.Fprint(global.out, msg+"\n") - global.mu.Unlock() -} - -func SetOutput(w io.Writer) { - global.mu.Lock() - global.out = w - global.mu.Unlock() -} - -func Enable() { global.enabled.Store(true) } -func Disable() { global.enabled.Store(false) } -func IsEnabled() bool { return global.enabled.Load() } -func EnableColors() { global.colors.Store(true) } -func DisableColors() { global.colors.Store(false) } -func ColorsEnabled() bool { return global.colors.Load() } -func Timestamp(enabled bool) { global.timestamp.Store(enabled) } -func Debug(enabled bool) { global.debug.Store(enabled) } -func IsDebug() bool { return global.debug.Load() } -func Http(enabled bool) { global.debug.Store(enabled) } diff --git a/moonshark.go b/moonshark.go index 702bd42..f34d399 100644 --- a/moonshark.go +++ b/moonshark.go @@ -1,362 +1,68 @@ package main import ( - "Moonshark/config" - "Moonshark/http" - "Moonshark/logger" - "Moonshark/metadata" - "Moonshark/router" - "Moonshark/runner" - "Moonshark/sessions" - "Moonshark/utils" - "Moonshark/watchers" - "bytes" - "context" - "flag" "fmt" "os" "os/signal" - "path/filepath" - "strconv" - "strings" "syscall" - "time" - color "git.sharkk.net/Go/Color" + "Moonshark/http" - fin "git.sharkk.net/Sharkk/Fin" - "github.com/valyala/fasthttp" -) - -var ( - cfg *config.Config // Server config from Fin file - rtr *router.Router // Lua file router - rnr *runner.Runner // Lua runner - svr *fasthttp.Server // FastHTTP server - pub fasthttp.RequestHandler // Public asset handler - snm *sessions.SessionManager // Session data manager - wmg *watchers.WatcherManager // Watcher manager - dbg bool // Debug mode flag - pubPfx []byte // Cached public asset prefix + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) func main() { - cfgPath := flag.String("config", "config", "Path to Fin config file") - dbgFlag := flag.Bool("debug", false, "Force debug mode") - sptPath := flag.String("script", "", "Path to Lua script to execute once") - flag.Parse() - - // Init sequence - sptMode := *sptPath != "" - color.SetColors(color.DetectShellColors()) - banner(sptMode) - - // Load Fin-based config - cfg = config.New(fin.LoadFromFile(*cfgPath)) - - // Setup debug mode - dbg = *dbgFlag || cfg.Server.Debug - logger.Debug(dbg) - logger.Debugf("Debug logging enabled") // Only prints if dbg is true - utils.Debug(dbg) // @TODO find a better way to do this - - // Determine Lua runner pool size - poolSize := cfg.Runner.PoolSize - if sptMode { - poolSize = 1 + if len(os.Args) < 2 { + fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) + os.Exit(1) } - // Set up the Lua runner - if err := initRunner(poolSize); err != nil { - logger.Fatalf("Runner failed to init: %v", err) + luaFile := os.Args[1] + + if _, err := os.Stat(luaFile); os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "Error: File '%s' not found\n", luaFile) + os.Exit(1) } - // If in script mode, attempt to run the Lua script at the given path - if sptMode { - if err := handleScriptMode(*sptPath); err != nil { - logger.Fatalf("Script execution failed: %v", err) - } - - shutdown() - return + // Create long-lived LuaJIT state + L := luajit.New(true) + if L == nil { + fmt.Fprintf(os.Stderr, "Error: Failed to create Lua state\n") + os.Exit(1) } - - // Set up the Lua router - if err := initRouter(); err != nil { - logger.Fatalf("Router failed to init: %s", color.Red(err.Error())) - } - - // Set up the file watcher manager - if err := setupWatchers(); err != nil { - logger.Fatalf("Watcher manager failed to init: %s", color.Red(err.Error())) - } - - // Set up the HTTP portion of the server - logger.Http(cfg.Server.HTTPLogging) // Whether we'll log HTTP request results - svr = http.NewHttpServer(cfg, requestMux, dbg) - pub = http.NewPublicHandler(cfg.Dirs.Public, cfg.Server.PublicPrefix) - pubPfx = []byte(cfg.Server.PublicPrefix) // Avoids casting to []byte when check prefixes - snm = sessions.NewSessionManager(sessions.DefaultMaxSessions) - - // Start the HTTP server - logger.Infof("Surf's up on port %s!", color.Cyan(strconv.Itoa(cfg.Server.Port))) - go func() { - if err := svr.ListenAndServe(":" + strconv.Itoa(cfg.Server.Port)); err != nil { - if err.Error() != "http: Server closed" { - logger.Errorf("Server error: %v", err) - } - } + defer func() { + L.Cleanup() + L.Close() }() - // Handle a shutdown signal - stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt, syscall.SIGTERM) - <-stop - - fmt.Print("\n") - logger.Infof("Shutdown signal received") - shutdown() -} - -// This is the primary request handler mux - determines whether we need to handle a Lua -// route or if we're serving a static file. -func requestMux(ctx *fasthttp.RequestCtx) { - start := time.Now() - method := ctx.Method() - path := ctx.Path() - - // Handle static file request - if bytes.HasPrefix(path, pubPfx) { - pub(ctx) - logRequest(ctx, method, path, start) - return + // Register HTTP functions + if err := http.RegisterHTTPFunctions(L); err != nil { + fmt.Fprintf(os.Stderr, "Error registering HTTP functions: %v\n", err) + os.Exit(1) } - // See if the requested route even exists - bytecode, params, found := rtr.Lookup(string(method), string(path)) - if !found { - http.Send404(ctx) - logRequest(ctx, method, path, start) - return + // Execute the Lua file + if err := L.DoFile(luaFile); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) } - // If there's no bytecode then it's an internal server error - if len(bytecode) == 0 { - http.Send500(ctx, nil) - logRequest(ctx, method, path, start) - } - - // We've made it this far so the endpoint will likely load. Let's get any session data - // for this request - session := snm.GetSessionFromRequest(ctx) - - // Let's build an HTTP context for the Lua runner to consume - luaCtx := runner.NewHTTPContext(ctx, params, session) - defer luaCtx.Release() - - // Ask the runner to execute our endpoint with our context - res, err := rnr.Execute(bytecode, luaCtx) - if err != nil { - logger.Errorf("Lua execution error: %v", err) - http.Send500(ctx, err) - logRequest(ctx, method, path, start) - return - } - - // Sweet, our execution went through! Let's now use the Response we got and build the HTTP response, then return - // the response object to be cleaned. After, we'll log our request cus we are *done* - applyResponse(ctx, res, session) - runner.ReleaseResponse(res) - logRequest(ctx, method, path, start) -} - -func applyResponse(ctx *fasthttp.RequestCtx, resp *runner.Response, session *sessions.Session) { - // Handle session updates - if len(resp.SessionData) > 0 { - if _, clearAll := resp.SessionData["__clear_all"]; clearAll { - session.Clear() - session.ClearFlash() - delete(resp.SessionData, "__clear_all") - } - - for k, v := range resp.SessionData { - if v == "__DELETE__" { - session.Delete(k) - } else { - session.Set(k, v) + // Handle return value for immediate exit + if L.GetTop() > 0 { + if L.IsNumber(1) { + exitCode := int(L.ToNumber(1)) + if exitCode != 0 { + os.Exit(exitCode) } + } else if L.IsBoolean(1) && !L.ToBoolean(1) { + os.Exit(1) } } - // Handle flash data - if flashData, ok := resp.Metadata["flash"].(map[string]any); ok { - for k, v := range flashData { - if err := session.FlashSafe(k, v); err != nil && dbg { - logger.Warnf("Error setting flash data %s: %v", k, err) - } - } - } - - // Apply session cookie - snm.ApplySessionCookie(ctx, session) - - // Apply HTTP response - http.ApplyResponse(resp, ctx) -} - -// Attempts to start the Lua runner. poolSize allows overriding the config, like for script mode. A poolSize of -// 0 will default to the config, and if the config is 0 then it will default to GOMAXPROCS. -func initRunner(poolSize int) error { - for _, dir := range cfg.Dirs.Libs { - if !dirExists(dir) { - logger.Warnf("Lib directory not found... %s", color.Yellow(dir)) - } - } - - runner, err := runner.NewRunner(cfg, poolSize) - if err != nil { - return fmt.Errorf("lua runner init failed: %v", err) - } - rnr = runner - - logger.Infof("LuaRunner is g2g with %s states!", color.Yellow(strconv.Itoa(poolSize))) - return nil -} - -// Attempt to spin up the Lua router. Attempts to create the routes directory if it doesn't exist, -// since it's required for Moonshark to work. -func initRouter() error { - if err := os.MkdirAll(cfg.Dirs.Routes, 0755); err != nil { - return fmt.Errorf("failed to create routes directory: %w", err) - } - - router, err := router.New(cfg.Dirs.Routes) - if err != nil { - return fmt.Errorf("lua router init failed: %v", err) - } - rtr = router - - logger.Infof("LuaRouter is g2g! %s", color.Yellow(cfg.Dirs.Routes)) - return nil -} - -// Set up the file watchers. -func setupWatchers() error { - wmg = watchers.NewWatcherManager() - - // Router watcher - err := wmg.WatchDirectory(watchers.WatcherConfig{ - Dir: cfg.Dirs.Routes, - Callback: rtr.Refresh, - Recursive: true, - }) - if err != nil { - return fmt.Errorf("failed to watch routes directory: %v", err) - } - - logger.Infof("Started watching Lua routes! %s", color.Yellow(cfg.Dirs.Routes)) - - // Libs watchers - for _, dir := range cfg.Dirs.Libs { - err := wmg.WatchDirectory(watchers.WatcherConfig{ - Dir: dir, - Callback: func(changes []watchers.FileChange) error { - for _, change := range changes { - if !change.IsDeleted && strings.HasSuffix(change.Path, ".lua") { - rnr.NotifyFileChanged(change.Path) - } - } - return nil - }, - Recursive: true, - }) - if err != nil { - return fmt.Errorf("failed to watch modules directory: %v", err) - } - - logger.Infof("Started watching Lua modules! %s", color.Yellow(dir)) - } - - return nil -} - -// Attempts to execute the Lua script at the given path inside a fully initialized sandbox environment. Handy -// for pre-launch tasks and the like. -func handleScriptMode(path string) error { - path, err := filepath.Abs(path) - if err != nil { - return fmt.Errorf("failed to resolve script path: %v", err) - } - - if _, err := os.Stat(path); os.IsNotExist(err) { - return fmt.Errorf("script file not found: %s", path) - } - - logger.Infof("Executing: %s", path) - - resp, err := rnr.RunScriptFile(path) - if err != nil { - return fmt.Errorf("execution failed: %v", err) - } - - if resp != nil && resp.Body != nil { - logger.Infof("Script result: %v", resp.Body) - } else { - logger.Infof("Script executed successfully (no return value)") - } - - return nil -} - -func shutdown() { - logger.Infof("Shutting down...") - - // Close down the HTTP server - if svr != nil { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := svr.ShutdownWithContext(ctx); err != nil { - logger.Errorf("HTTP server shutdown error: %v", err) - } - } - - // Close down the Lua runner if it exists - if rnr != nil { - rnr.Close() - } - - // Close down the watcher manager if it exists - if wmg != nil { - wmg.Close() - } - - logger.Infof("Shutdown complete") -} - -// Print our super-awesome banner with the current version! -func banner(scriptMode bool) { - if scriptMode { - fmt.Println(color.Blue(fmt.Sprintf("Moonshark %s << Script Mode >>", metadata.Version))) - return - } - - banner := ` - _____ _________.__ __ - / \ ____ ____ ____ / _____/| |__ _____ _______| | __ - / \ / \ / _ \ / _ \ / \ \_____ \ | | \\__ \\_ __ \ |/ / -/ Y ( <_> | <_> ) | \/ \| Y \/ __ \| | \/ < -\____|__ /\____/ \____/|___| /_______ /|___| (____ /__| |__|_ \ %s - \/ \/ \/ \/ \/ \/ - ` - fmt.Println(color.Blue(fmt.Sprintf(banner, metadata.Version))) -} - -func dirExists(path string) bool { - info, err := os.Stat(path) - return err == nil && info.IsDir() -} - -func logRequest(ctx *fasthttp.RequestCtx, method, path []byte, start time.Time) { - logger.Request(ctx.Response.StatusCode(), method, path, time.Since(start)) + // Keep running for HTTP server + fmt.Println("Script executed. Press Ctrl+C to exit.") + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + fmt.Println("\nShutting down...") } diff --git a/router/router.go b/router/router.go deleted file mode 100644 index b93aa62..0000000 --- a/router/router.go +++ /dev/null @@ -1,502 +0,0 @@ -package router - -import ( - "Moonshark/watchers" - "errors" - "os" - "path/filepath" - "strings" - "sync" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "github.com/VictoriaMetrics/fastcache" -) - -// node represents a node in the radix trie -type node struct { - segment string - bytecode []byte - scriptPath string - children []*node - isDynamic bool - isWildcard bool - maxParams uint8 -} - -// Router is a filesystem-based HTTP router for Lua files with bytecode caching -type Router struct { - routesDir string - get, post, put, patch, delete *node - bytecodeCache *fastcache.Cache - compileState *luajit.State - compileMu sync.Mutex - paramsBuffer []string - middlewareFiles map[string][]string // filesystem path -> middleware file paths -} - -// Params holds URL parameters -type Params struct { - Keys []string - Values []string -} - -// Get returns a parameter value by name -func (p *Params) Get(name string) string { - for i, key := range p.Keys { - if key == name && i < len(p.Values) { - return p.Values[i] - } - } - return "" -} - -// New creates a new Router instance -func New(routesDir string) (*Router, error) { - info, err := os.Stat(routesDir) - if err != nil { - return nil, err - } - if !info.IsDir() { - return nil, errors.New("routes path is not a directory") - } - - compileState := luajit.New() - if compileState == nil { - return nil, errors.New("failed to create Lua compile state") - } - - r := &Router{ - routesDir: routesDir, - get: &node{}, - post: &node{}, - put: &node{}, - patch: &node{}, - delete: &node{}, - bytecodeCache: fastcache.New(32 * 1024 * 1024), // 32MB - compileState: compileState, - paramsBuffer: make([]string, 64), - middlewareFiles: make(map[string][]string), - } - - return r, r.buildRoutes() -} - -// methodNode returns the root node for a method -func (r *Router) methodNode(method string) *node { - switch method { - case "GET": - return r.get - case "POST": - return r.post - case "PUT": - return r.put - case "PATCH": - return r.patch - case "DELETE": - return r.delete - default: - return nil - } -} - -// buildRoutes scans the routes directory and builds the routing tree -func (r *Router) buildRoutes() error { - r.middlewareFiles = make(map[string][]string) - - // First pass: collect all middleware files - err := filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error { - if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") { - return err - } - - if strings.TrimSuffix(info.Name(), ".lua") == "middleware" { - relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path)) - if err != nil { - return err - } - - fsPath := "/" - if relDir != "." { - fsPath = "/" + strings.ReplaceAll(relDir, "\\", "/") - } - - r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path) - } - - return nil - }) - - if err != nil { - return err - } - - // Second pass: build routes - return filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error { - if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") { - return err - } - - fileName := strings.TrimSuffix(info.Name(), ".lua") - - // Skip middleware files - if fileName == "middleware" { - return nil - } - - // Get relative path from routes directory - relPath, err := filepath.Rel(r.routesDir, path) - if err != nil { - return err - } - - // Get filesystem path (includes groups) - fsPath := "/" + strings.ReplaceAll(filepath.Dir(relPath), "\\", "/") - if fsPath == "/." { - fsPath = "/" - } - - // Get URL path (excludes groups) - urlPath := r.parseURLPath(fsPath) - - // Handle method files (get.lua, post.lua, etc.) - method := strings.ToUpper(fileName) - root := r.methodNode(method) - if root != nil { - return r.addRoute(root, urlPath, fsPath, path) - } - - // Handle index files - register for all methods - if fileName == "index" { - for _, method := range []string{"GET", "POST", "PUT", "PATCH", "DELETE"} { - if root := r.methodNode(method); root != nil { - if err := r.addRoute(root, urlPath, fsPath, path); err != nil { - return err - } - } - } - return nil - } - - // Handle named route files - register as GET by default - namedPath := urlPath - if urlPath == "/" { - namedPath = "/" + fileName - } else { - namedPath = urlPath + "/" + fileName - } - return r.addRoute(r.get, namedPath, fsPath, path) - }) -} - -// parseURLPath strips group segments from filesystem path -func (r *Router) parseURLPath(fsPath string) string { - segments := strings.Split(strings.Trim(fsPath, "/"), "/") - var urlSegments []string - - for _, segment := range segments { - if segment == "" { - continue - } - // Skip group segments (enclosed in parentheses) - if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") { - continue - } - urlSegments = append(urlSegments, segment) - } - - if len(urlSegments) == 0 { - return "/" - } - return "/" + strings.Join(urlSegments, "/") -} - -// getMiddlewareChain returns middleware files that apply to the given filesystem path -func (r *Router) getMiddlewareChain(fsPath string) []string { - var chain []string - - pathParts := strings.Split(strings.Trim(fsPath, "/"), "/") - if pathParts[0] == "" { - pathParts = []string{} - } - - // Add root middleware - if mw, exists := r.middlewareFiles["/"]; exists { - chain = append(chain, mw...) - } - - // Add middleware from each path level (including groups) - currentPath := "" - for _, part := range pathParts { - currentPath += "/" + part - if mw, exists := r.middlewareFiles[currentPath]; exists { - chain = append(chain, mw...) - } - } - - return chain -} - -// buildCombinedSource combines middleware and handler source -func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) { - var combined strings.Builder - - // Add middleware in order - middlewareChain := r.getMiddlewareChain(fsPath) - for _, mwPath := range middlewareChain { - content, err := os.ReadFile(mwPath) - if err != nil { - return "", err - } - combined.WriteString("-- Middleware: ") - combined.WriteString(mwPath) - combined.WriteString("\n") - combined.Write(content) - combined.WriteString("\n") - } - - // Add main handler - content, err := os.ReadFile(scriptPath) - if err != nil { - return "", err - } - combined.WriteString("-- Handler: ") - combined.WriteString(scriptPath) - combined.WriteString("\n") - combined.Write(content) - - return combined.String(), nil -} - -// addRoute adds a new route to the trie with bytecode compilation -func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error { - // Build combined source with middleware - combinedSource, err := r.buildCombinedSource(fsPath, scriptPath) - if err != nil { - return err - } - - // Compile bytecode - r.compileMu.Lock() - bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath) - r.compileMu.Unlock() - - if err != nil { - return err - } - - // Cache bytecode - cacheKey := hashString(scriptPath) - r.bytecodeCache.Set(uint64ToBytes(cacheKey), bytecode) - - if urlPath == "/" { - root.bytecode = bytecode - root.scriptPath = scriptPath - return nil - } - - current := root - pos := 0 - paramCount := uint8(0) - - for { - seg, newPos, more := readSegment(urlPath, pos) - if seg == "" { - break - } - - isDyn := len(seg) > 2 && seg[0] == '[' && seg[len(seg)-1] == ']' - isWC := len(seg) > 0 && seg[0] == '*' - - if isWC && more { - return errors.New("wildcard must be the last segment") - } - - if isDyn || isWC { - paramCount++ - } - - // Find or create child - var child *node - for _, c := range current.children { - if c.segment == seg { - child = c - break - } - } - - if child == nil { - child = &node{ - segment: seg, - isDynamic: isDyn, - isWildcard: isWC, - } - current.children = append(current.children, child) - } - - if child.maxParams < paramCount { - child.maxParams = paramCount - } - - current = child - pos = newPos - } - - current.bytecode = bytecode - current.scriptPath = scriptPath - return nil -} - -// readSegment extracts the next path segment -func readSegment(path string, start int) (segment string, end int, hasMore bool) { - if start >= len(path) { - return "", start, false - } - if path[start] == '/' { - start++ - } - if start >= len(path) { - return "", start, false - } - end = start - for end < len(path) && path[end] != '/' { - end++ - } - return path[start:end], end, end < len(path) -} - -// Lookup finds bytecode and parameters for a method and path -func (r *Router) Lookup(method, path string) ([]byte, *Params, bool) { - root := r.methodNode(method) - if root == nil { - return nil, nil, false - } - - if path == "/" { - if root.bytecode != nil { - return root.bytecode, &Params{}, true - } - return nil, nil, false - } - - // Prepare params buffer - buffer := r.paramsBuffer - if cap(buffer) < int(root.maxParams) { - buffer = make([]string, root.maxParams) - r.paramsBuffer = buffer - } - buffer = buffer[:0] - - var keys []string - bytecode, paramCount, found := r.match(root, path, 0, &buffer, &keys) - if !found { - return nil, nil, false - } - - params := &Params{ - Keys: keys[:paramCount], - Values: buffer[:paramCount], - } - - return bytecode, params, true -} - -// match traverses the trie to find bytecode -func (r *Router) match(current *node, path string, start int, params *[]string, keys *[]string) ([]byte, int, bool) { - paramCount := 0 - - // Check wildcard first - for _, c := range current.children { - if c.isWildcard { - rem := path[start:] - if len(rem) > 0 && rem[0] == '/' { - rem = rem[1:] - } - *params = append(*params, rem) - *keys = append(*keys, strings.TrimPrefix(c.segment, "*")) - return c.bytecode, 1, c.bytecode != nil - } - } - - seg, pos, more := readSegment(path, start) - if seg == "" { - return current.bytecode, 0, current.bytecode != nil - } - - for _, c := range current.children { - if c.segment == seg || c.isDynamic { - if c.isDynamic { - *params = append(*params, seg) - paramName := c.segment[1 : len(c.segment)-1] // Remove [ ] - *keys = append(*keys, paramName) - paramCount++ - } - - if !more { - return c.bytecode, paramCount, c.bytecode != nil - } - - bytecode, nestedCount, ok := r.match(c, path, pos, params, keys) - if ok { - return bytecode, paramCount + nestedCount, true - } - - // Backtrack on failure - if c.isDynamic { - *params = (*params)[:len(*params)-1] - *keys = (*keys)[:len(*keys)-1] - } - } - } - - return nil, 0, false -} - -// GetBytecode gets cached bytecode by script path -func (r *Router) GetBytecode(scriptPath string) []byte { - cacheKey := hashString(scriptPath) - return r.bytecodeCache.Get(nil, uint64ToBytes(cacheKey)) -} - -// Refresh rebuilds the router -func (r *Router) Refresh(changes []watchers.FileChange) error { - r.get = &node{} - r.post = &node{} - r.put = &node{} - r.patch = &node{} - r.delete = &node{} - r.middlewareFiles = make(map[string][]string) - r.bytecodeCache.Reset() - return r.buildRoutes() -} - -// Close cleans up resources -func (r *Router) Close() { - r.compileMu.Lock() - if r.compileState != nil { - r.compileState.Close() - r.compileState = nil - } - r.compileMu.Unlock() -} - -// Helper functions from cache.go -func hashString(s string) uint64 { - h := uint64(5381) - for i := 0; i < len(s); i++ { - h = ((h << 5) + h) + uint64(s[i]) - } - return h -} - -func uint64ToBytes(n uint64) []byte { - b := make([]byte, 8) - b[0] = byte(n) - b[1] = byte(n >> 8) - b[2] = byte(n >> 16) - b[3] = byte(n >> 24) - b[4] = byte(n >> 32) - b[5] = byte(n >> 40) - b[6] = byte(n >> 48) - b[7] = byte(n >> 56) - return b -} diff --git a/router/router_test.go b/router/router_test.go deleted file mode 100644 index be44d27..0000000 --- a/router/router_test.go +++ /dev/null @@ -1,318 +0,0 @@ -package router - -import ( - "os" - "path/filepath" - "testing" -) - -func setupTestRoutes(t testing.TB) string { - t.Helper() - - tempDir := t.TempDir() - - // Create test route files - routes := map[string]string{ - "index.lua": `return "home"`, - "about.lua": `return "about"`, - "api/users.lua": `return "users"`, - "api/users/get.lua": `return "get_users"`, - "api/users/post.lua": `return "create_user"`, - "api/users/[id].lua": `return "user_" .. id`, - "api/posts/[slug]/comments.lua": `return "comments_" .. slug`, - "files/*path.lua": `return "file_" .. path`, - "middleware.lua": `-- root middleware`, - "api/middleware.lua": `-- api middleware`, - } - - for path, content := range routes { - fullPath := filepath.Join(tempDir, path) - dir := filepath.Dir(fullPath) - - if err := os.MkdirAll(dir, 0755); err != nil { - t.Fatal(err) - } - - if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { - t.Fatal(err) - } - } - - return tempDir -} - -func TestRouterBasicFunctionality(t *testing.T) { - routesDir := setupTestRoutes(t) - router, err := New(routesDir) - if err != nil { - t.Fatal(err) - } - defer router.Close() - - tests := []struct { - method string - path string - expected bool - params map[string]string - }{ - {"GET", "/", true, nil}, - {"GET", "/about", true, nil}, - {"GET", "/api/users", true, nil}, - {"GET", "/api/users", true, nil}, - {"POST", "/api/users", true, nil}, - {"GET", "/api/users/123", true, map[string]string{"id": "123"}}, - {"GET", "/api/posts/hello-world/comments", true, map[string]string{"slug": "hello-world"}}, - {"GET", "/files/docs/readme.txt", true, map[string]string{"path": "docs/readme.txt"}}, - {"GET", "/nonexistent", false, nil}, - {"DELETE", "/api/users", false, nil}, - } - - for _, tt := range tests { - t.Run(tt.method+"_"+tt.path, func(t *testing.T) { - bytecode, params, found := router.Lookup(tt.method, tt.path) - - if found != tt.expected { - t.Errorf("expected found=%v, got %v", tt.expected, found) - } - - if tt.expected { - if bytecode == nil { - t.Error("expected bytecode, got nil") - } - - if tt.params != nil { - for key, expectedValue := range tt.params { - if actualValue := params.Get(key); actualValue != expectedValue { - t.Errorf("param %s: expected %s, got %s", key, expectedValue, actualValue) - } - } - } - } - }) - } -} - -func TestRouterParamsStruct(t *testing.T) { - params := &Params{ - Keys: []string{"id", "slug"}, - Values: []string{"123", "hello"}, - } - - if params.Get("id") != "123" { - t.Errorf("expected '123', got '%s'", params.Get("id")) - } - - if params.Get("slug") != "hello" { - t.Errorf("expected 'hello', got '%s'", params.Get("slug")) - } - - if params.Get("missing") != "" { - t.Errorf("expected empty string for missing param, got '%s'", params.Get("missing")) - } -} - -func TestRouterMethodNodes(t *testing.T) { - routesDir := setupTestRoutes(t) - router, err := New(routesDir) - if err != nil { - t.Fatal(err) - } - defer router.Close() - - // Test that different methods work independently - _, _, foundGet := router.Lookup("GET", "/api/users") - _, _, foundPost := router.Lookup("POST", "/api/users") - _, _, foundPut := router.Lookup("PUT", "/api/users") - - if !foundGet { - t.Error("GET /api/users should be found") - } - if !foundPost { - t.Error("POST /api/users should be found") - } - if foundPut { - t.Error("PUT /api/users should not be found") - } -} - -func TestRouterWildcardValidation(t *testing.T) { - tempDir := t.TempDir() - - // Create invalid wildcard route (not at end) - invalidPath := filepath.Join(tempDir, "bad/*path/more.lua") - if err := os.MkdirAll(filepath.Dir(invalidPath), 0755); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(invalidPath, []byte(`return "bad"`), 0644); err != nil { - t.Fatal(err) - } - - _, err := New(tempDir) - if err == nil { - t.Error("expected error for wildcard not at end") - } -} - -func BenchmarkLookupStatic(b *testing.B) { - routesDir := setupTestRoutes(b) - router, err := New(routesDir) - if err != nil { - b.Fatal(err) - } - defer router.Close() - - method := "GET" - path := "/api/users" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - _, _, _ = router.Lookup(method, path) - } -} - -func BenchmarkLookupDynamic(b *testing.B) { - routesDir := setupTestRoutes(b) - router, err := New(routesDir) - if err != nil { - b.Fatal(err) - } - defer router.Close() - - method := "GET" - path := "/api/users/12345" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - _, _, _ = router.Lookup(method, path) - } -} - -func BenchmarkLookupWildcard(b *testing.B) { - routesDir := setupTestRoutes(b) - router, err := New(routesDir) - if err != nil { - b.Fatal(err) - } - defer router.Close() - - method := "GET" - path := "/files/docs/deep/nested/file.txt" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - _, _, _ = router.Lookup(method, path) - } -} - -func BenchmarkLookupComplex(b *testing.B) { - routesDir := setupTestRoutes(b) - router, err := New(routesDir) - if err != nil { - b.Fatal(err) - } - defer router.Close() - - method := "GET" - path := "/api/posts/my-blog-post-title/comments" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - _, _, _ = router.Lookup(method, path) - } -} - -func BenchmarkLookupNotFound(b *testing.B) { - routesDir := setupTestRoutes(b) - router, err := New(routesDir) - if err != nil { - b.Fatal(err) - } - defer router.Close() - - method := "GET" - path := "/this/path/does/not/exist" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - _, _, _ = router.Lookup(method, path) - } -} - -func BenchmarkLookupMixed(b *testing.B) { - routesDir := setupTestRoutes(b) - router, err := New(routesDir) - if err != nil { - b.Fatal(err) - } - defer router.Close() - - paths := []string{ - "/", - "/about", - "/api/users", - "/api/users/123", - "/api/posts/hello/comments", - "/files/document.pdf", - "/nonexistent", - } - method := "GET" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - path := paths[i%len(paths)] - _, _, _ = router.Lookup(method, path) - } -} - -// Comparison benchmarks for string vs byte slice performance -func BenchmarkLookupStringConversion(b *testing.B) { - routesDir := setupTestRoutes(b) - router, err := New(routesDir) - if err != nil { - b.Fatal(err) - } - defer router.Close() - - methodStr := "GET" - pathStr := "/api/users/12345" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - // Direct string usage - _, _, _ = router.Lookup(methodStr, pathStr) - } -} - -func BenchmarkLookupPreallocated(b *testing.B) { - routesDir := setupTestRoutes(b) - router, err := New(routesDir) - if err != nil { - b.Fatal(err) - } - defer router.Close() - - // Pre-allocated strings (optimal case) - method := "GET" - path := "/api/users/12345" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - _, _, _ = router.Lookup(method, path) - } -} diff --git a/runner/context.go b/runner/context.go deleted file mode 100644 index d7388ad..0000000 --- a/runner/context.go +++ /dev/null @@ -1,58 +0,0 @@ -package runner - -import ( - "sync" -) - -// Generic interface to support different types of execution contexts for the runner -type ExecutionContext interface { - Get(key string) any - Set(key string, value any) - ToMap() map[string]any - Release() -} - -// This is a generic context that satisfies the runner's ExecutionContext interface -type Context struct { - Values map[string]any // Any data we want to pass to the state's global ctx table. -} - -// Context pool to reduce allocations -var contextPool = sync.Pool{ - New: func() any { - return &Context{ - Values: make(map[string]any, 32), - } - }, -} - -// Gets a new context from the pool -func NewContext() *Context { - ctx := contextPool.Get().(*Context) - return ctx -} - -// Release returns the context to the pool after clearing its values -func (c *Context) Release() { - // Clear all values to prevent data leakage - for k := range c.Values { - delete(c.Values, k) - } - - contextPool.Put(c) -} - -// Set adds a value to the context -func (c *Context) Set(key string, value any) { - c.Values[key] = value -} - -// Get retrieves a value from the context -func (c *Context) Get(key string) any { - return c.Values[key] -} - -// We can just return the Values map as it's already g2g for Lua -func (c *Context) ToMap() map[string]any { - return c.Values -} diff --git a/runner/embed.go b/runner/embed.go deleted file mode 100644 index dcafd2b..0000000 --- a/runner/embed.go +++ /dev/null @@ -1,130 +0,0 @@ -package runner - -import ( - "Moonshark/logger" - _ "embed" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -//go:embed lua/sandbox.lua -var sandboxLuaCode string - -//go:embed lua/json.lua -var jsonLuaCode string - -//go:embed lua/sqlite.lua -var sqliteLuaCode string - -//go:embed lua/fs.lua -var fsLuaCode string - -//go:embed lua/util.lua -var utilLuaCode string - -//go:embed lua/string.lua -var stringLuaCode string - -//go:embed lua/table.lua -var tableLuaCode string - -//go:embed lua/crypto.lua -var cryptoLuaCode string - -//go:embed lua/time.lua -var timeLuaCode string - -//go:embed lua/math.lua -var mathLuaCode string - -//go:embed lua/env.lua -var envLuaCode string - -//go:embed lua/http.lua -var httpLuaCode string - -//go:embed lua/cookie.lua -var cookieLuaCode string - -//go:embed lua/csrf.lua -var csrfLuaCode string - -//go:embed lua/render.lua -var renderLuaCode string - -//go:embed lua/session.lua -var sessionLuaCode string - -//go:embed lua/timestamp.lua -var timestampLuaCode string - -// Module represents a Lua module to load -type Module struct { - name string - code string - global bool // true if module defines globals, false if it returns a table -} - -var modules = []Module{ - {"http", httpLuaCode, true}, - {"string", stringLuaCode, false}, - {"table", tableLuaCode, false}, - {"util", utilLuaCode, true}, - {"cookie", cookieLuaCode, true}, - {"session", sessionLuaCode, true}, - {"csrf", csrfLuaCode, true}, - {"render", renderLuaCode, true}, - {"json", jsonLuaCode, true}, - {"fs", fsLuaCode, true}, - {"crypto", cryptoLuaCode, true}, - {"time", timeLuaCode, false}, - {"math", mathLuaCode, false}, - {"env", envLuaCode, true}, - {"sqlite", sqliteLuaCode, true}, - {"timestamp", timestampLuaCode, false}, -} - -// loadModule loads a single module into the Lua state -func loadModule(state *luajit.State, m Module) error { - if m.global { - // Module defines globals directly, just execute it - return state.DoString(m.code) - } - - // Module returns a table, capture it and set as global - if err := state.LoadString(m.code); err != nil { - return err - } - if err := state.Call(0, 1); err != nil { - return err - } - state.SetGlobal(m.name) - return nil -} - -// loadSandboxIntoState loads all modules and sandbox into a Lua state -func loadSandboxIntoState(state *luajit.State, verbose bool) error { - // Load all utility modules - for _, module := range modules { - if err := loadModule(state, module); err != nil { - if verbose { - logger.Errorf("Failed to load %s module: %v", module.name, err) - } - return err - } - if verbose { - logger.Debugf("Loaded %s.lua", module.name) - } - } - - // Initialize module-specific globals - if err := state.DoString(`__active_sqlite_connections = {}`); err != nil { - return err - } - - // Load sandbox last - defines __execute and core functions - if verbose { - logger.Debugf("Loading sandbox.lua") - } - return state.DoString(sandboxLuaCode) -} diff --git a/runner/httpContext.go b/runner/httpContext.go deleted file mode 100644 index 364cde3..0000000 --- a/runner/httpContext.go +++ /dev/null @@ -1,136 +0,0 @@ -package runner - -import ( - "Moonshark/router" - "Moonshark/runner/lualibs" - "Moonshark/sessions" - "Moonshark/utils" - "sync" - - "maps" - - "github.com/valyala/fasthttp" -) - -// A prebuilt, ready-to-go context for HTTP requests to the runner. -type HTTPContext struct { - Values map[string]any // Contains all context data for Lua - - // Separate maps for efficient access during context building - headers map[string]string - cookies map[string]string - query map[string]string - params map[string]string - form map[string]any - session map[string]any - env map[string]any -} - -// HTTP context pool to reduce allocations -var httpContextPool = sync.Pool{ - New: func() any { - return &HTTPContext{ - Values: make(map[string]any, 32), - headers: make(map[string]string, 16), - cookies: make(map[string]string, 8), - query: make(map[string]string, 8), - params: make(map[string]string, 4), - form: make(map[string]any, 8), - session: make(map[string]any, 4), - env: make(map[string]any, 16), - } - }, -} - -// Get a clean HTTP context from the pool and build it up with an HTTP request, router params and session data -func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *HTTPContext { - ctx := httpContextPool.Get().(*HTTPContext) - - // Extract headers - httpCtx.Request.Header.VisitAll(func(key, value []byte) { - ctx.headers[string(key)] = string(value) - }) - - // Extract cookies - httpCtx.Request.Header.VisitAllCookie(func(key, value []byte) { - ctx.cookies[string(key)] = string(value) - }) - - // Extract query params - httpCtx.QueryArgs().VisitAll(func(key, value []byte) { - ctx.query[string(key)] = string(value) - }) - - // Extract route parameters - if params != nil { - for i := range min(len(params.Keys), len(params.Values)) { - ctx.params[params.Keys[i]] = params.Values[i] - } - } - - // Extract form data if present - if httpCtx.IsPost() || httpCtx.IsPut() || httpCtx.IsPatch() { - if form, err := utils.ParseForm(httpCtx); err == nil { - maps.Copy(ctx.form, form) - } - } - - // Extract session data - session.AdvanceFlash() - ctx.session["id"] = session.ID - if session.IsEmpty() { - ctx.session["data"] = emptyMap - ctx.session["flash"] = emptyMap - } else { - ctx.session["data"] = session.GetAll() - ctx.session["flash"] = session.GetAllFlash() - } - - // Add environment vars - if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil { - maps.Copy(ctx.env, envMgr.GetAll()) - } - - // Populate Values with all context data - ctx.Values["method"] = string(httpCtx.Method()) - ctx.Values["path"] = string(httpCtx.Path()) - ctx.Values["host"] = string(httpCtx.Host()) - ctx.Values["headers"] = ctx.headers - ctx.Values["cookies"] = ctx.cookies - ctx.Values["query"] = ctx.query - ctx.Values["params"] = ctx.params - ctx.Values["form"] = ctx.form - ctx.Values["session"] = ctx.session - ctx.Values["env"] = ctx.env - - return ctx -} - -// Clear out all the request data from the context and give it back to the pool. -func (c *HTTPContext) Release() { - clear(c.Values) - clear(c.headers) - clear(c.cookies) - clear(c.query) - clear(c.params) - clear(c.form) - clear(c.session) - clear(c.env) - - httpContextPool.Put(c) -} - -// Add a value to the extras section -func (c *HTTPContext) Set(key string, value any) { - c.Values[key] = value -} - -// Get a value from the context -func (c *HTTPContext) Get(key string) any { - return c.Values[key] -} - -// Returns the Values map directly - zero overhead -func (c *HTTPContext) ToMap() map[string]any { - return c.Values -} diff --git a/runner/lua/cookie.lua b/runner/lua/cookie.lua deleted file mode 100644 index e3e4603..0000000 --- a/runner/lua/cookie.lua +++ /dev/null @@ -1,26 +0,0 @@ --- cookie.lua - -function cookie_set(name, value, options) - __response.cookies = __response.cookies or {} - local opts = options or {} - local cookie = { - name = name, - value = value or "", - path = opts.path or "/", - domain = opts.domain, - secure = opts.secure ~= false, - http_only = opts.http_only ~= false - } - if opts.expires and opts.expires > 0 then - cookie.max_age = opts.expires - end - table.insert(__response.cookies, cookie) -end - -function cookie_get(name) - return __ctx.cookies and __ctx.cookies[name] -end - -function cookie_delete(name, path, domain) - return cookie_set(name, "", {expires = -1, path = path or "/", domain = domain}) -end diff --git a/runner/lua/crypto.lua b/runner/lua/crypto.lua deleted file mode 100644 index c25f6d2..0000000 --- a/runner/lua/crypto.lua +++ /dev/null @@ -1,140 +0,0 @@ --- crypto.lua - --- ====================================================================== --- HASHING FUNCTIONS --- ====================================================================== - --- Generate hash digest using various algorithms --- Algorithms: md5, sha1, sha256, sha512 --- Formats: hex (default), binary -function hash(data, algorithm, format) - if type(data) ~= "string" then - error("hash: data must be a string", 2) - end - - algorithm = algorithm or "sha256" - format = format or "hex" - - return __crypto_hash(data, algorithm, format) -end - -function md5(data, format) - return hash(data, "md5", format) -end - -function sha1(data, format) - return hash(data, "sha1", format) -end - -function sha256(data, format) - return hash(data, "sha256", format) -end - -function sha512(data, format) - return hash(data, "sha512", format) -end - --- ====================================================================== --- HMAC FUNCTIONS --- ====================================================================== - --- Generate HMAC using various algorithms --- Algorithms: md5, sha1, sha256, sha512 --- Formats: hex (default), binary -function hmac(data, key, algorithm, format) - if type(data) ~= "string" then - error("hmac: data must be a string", 2) - end - - if type(key) ~= "string" then - error("hmac: key must be a string", 2) - end - - algorithm = algorithm or "sha256" - format = format or "hex" - - return __crypto_hmac(data, key, algorithm, format) -end - -function hmac_md5(data, key, format) - return hmac(data, key, "md5", format) -end - -function hmac_sha1(data, key, format) - return hmac(data, key, "sha1", format) -end - -function hmac_sha256(data, key, format) - return hmac(data, key, "sha256", format) -end - -function hmac_sha512(data, key, format) - return hmac(data, key, "sha512", format) -end - --- ====================================================================== --- RANDOM FUNCTIONS --- ====================================================================== - --- Generate random bytes --- Formats: binary (default), hex -function random_bytes(length, secure, format) - if type(length) ~= "number" or length <= 0 then - error("random_bytes: length must be positive", 2) - end - - secure = secure ~= false -- Default to secure - format = format or "binary" - - return __crypto_random_bytes(length, secure, format) -end - --- Generate random integer in range [min, max] -function random_int(min, max, secure) - if type(min) ~= "number" or type(max) ~= "number" then - error("random_int: min and max must be numbers", 2) - end - - if max <= min then - error("random_int: max must be greater than min", 2) - end - - secure = secure ~= false -- Default to secure - - return __crypto_random_int(min, max, secure) -end - --- Generate random string of specified length -function random_string(length, charset, secure) - if type(length) ~= "number" or length <= 0 then - error("random_string: length must be positive", 2) - end - - secure = secure ~= false -- Default to secure - - -- Default character set: alphanumeric - charset = charset or "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" - - if type(charset) ~= "string" or #charset == 0 then - error("random_string: charset must be non-empty", 2) - end - - local result = "" - local charset_length = #charset - - for i = 1, length do - local index = random_int(1, charset_length, secure) - result = result .. charset:sub(index, index) - end - - return result -end - --- ====================================================================== --- UUID FUNCTIONS --- ====================================================================== - --- Generate random UUID (v4) -function uuid() - return __crypto_uuid() -end diff --git a/runner/lua/csrf.lua b/runner/lua/csrf.lua deleted file mode 100644 index 2bac32a..0000000 --- a/runner/lua/csrf.lua +++ /dev/null @@ -1,33 +0,0 @@ --- csrf.lua - -function csrf_generate() - local token = generate_token(32) - session_set("_csrf_token", token) - return token -end - -function csrf_field() - local token = session_get("_csrf_token") - if not token then - token = csrf_generate() - end - return string.format('', - html_special_chars(token)) -end - -function csrf_validate() - local token = __ctx.session and __ctx.session.data and __ctx.session.data["_csrf_token"] - if not token then - __response.status = 403 - coroutine.yield("__EXIT__") - end - - local request_token = (__ctx._request_form and __ctx._request_form._csrf_token) or - (__ctx._request_headers and (__ctx._request_headers["x-csrf-token"] or __ctx._request_headers["csrf-token"])) - - if not request_token or request_token ~= token then - __response.status = 403 - coroutine.yield("__EXIT__") - end - return true -end diff --git a/runner/lua/env.lua b/runner/lua/env.lua deleted file mode 100644 index 047e4a5..0000000 --- a/runner/lua/env.lua +++ /dev/null @@ -1,89 +0,0 @@ --- env.lua - --- Get an environment variable with a default value -function env_get(key, default_value) - if type(key) ~= "string" then - error("env_get: key must be a string") - end - - -- Check context for environment variables - if __ctx and __ctx.env and __ctx.env[key] ~= nil then - return __ctx.env[key] - end - - return default_value -end - --- Set an environment variable -function env_set(key, value) - if type(key) ~= "string" then - error("env_set: key must be a string") - end - - -- Update context immediately for future reads - if __ctx then - __ctx.env = __ctx.env or {} - __ctx.env[key] = value - end - - -- Persist to Go backend - return __env_set(key, value) -end - --- Get all environment variables as a table -function env_get_all() - -- Return context table directly if available - if __ctx and __ctx.env then - local copy = {} - for k, v in pairs(__ctx.env) do - copy[k] = v - end - return copy - end - - -- Fallback to Go call - return __env_get_all() -end - --- Check if an environment variable exists -function env_exists(key) - if type(key) ~= "string" then - error("env_exists: key must be a string") - end - - -- Check context first - if __ctx and __ctx.env then - return __ctx.env[key] ~= nil - end - - return false -end - --- Set multiple environment variables from a table -function env_set_many(vars) - if type(vars) ~= "table" then - error("env_set_many: vars must be a table") - end - - if __ctx then - __ctx.env = __ctx.env or {} - end - - local success = true - for key, value in pairs(vars) do - if type(key) == "string" then - -- Update context - if __ctx and __ctx.env then - __ctx.env[key] = value - end - -- Persist to Go - if not __env_set(key, value) then - success = false - end - else - error("env_set_many: all keys must be strings") - end - end - - return success -end diff --git a/runner/lua/fs.lua b/runner/lua/fs.lua deleted file mode 100644 index b86d5ec..0000000 --- a/runner/lua/fs.lua +++ /dev/null @@ -1,136 +0,0 @@ --- fs.lua - -function fs_read(path) - if type(path) ~= "string" then - error("fs_read: path must be a string", 2) - end - return __fs_read_file(path) -end - -function fs_write(path, content) - if type(path) ~= "string" then - error("fs_write: path must be a string", 2) - end - if type(content) ~= "string" then - error("fs_write: content must be a string", 2) - end - return __fs_write_file(path, content) -end - -function fs_append(path, content) - if type(path) ~= "string" then - error("fs_append: path must be a string", 2) - end - if type(content) ~= "string" then - error("fs_append: content must be a string", 2) - end - return __fs_append_file(path, content) -end - -function fs_exists(path) - if type(path) ~= "string" then - error("fs_exists: path must be a string", 2) - end - return __fs_exists(path) -end - -function fs_remove(path) - if type(path) ~= "string" then - error("fs_remove: path must be a string", 2) - end - return __fs_remove_file(path) -end - -function fs_info(path) - if type(path) ~= "string" then - error("fs_info: path must be a string", 2) - end - local info = __fs_get_info(path) - - -- Convert the Unix timestamp to a readable date - if info and info.mod_time then - info.mod_time_str = os.date("%Y-%m-%d %H:%M:%S", info.mod_time) - end - - return info -end - --- Directory Operations -function fs_mkdir(path, mode) - if type(path) ~= "string" then - error("fs_mkdir: path must be a string", 2) - end - mode = mode or 0755 - return __fs_make_dir(path, mode) -end - -function fs_ls(path) - if type(path) ~= "string" then - error("fs_ls: path must be a string", 2) - end - return __fs_list_dir(path) -end - -function fs_rmdir(path, recursive) - if type(path) ~= "string" then - error("fs_rmdir: path must be a string", 2) - end - recursive = recursive or false - return __fs_remove_dir(path, recursive) -end - --- Path Operations -function fs_join_paths(...) - return __fs_join_paths(...) -end - -function fs_dir_name(path) - if type(path) ~= "string" then - error("fs_dir_name: path must be a string", 2) - end - return __fs_dir_name(path) -end - -function fs_base_name(path) - if type(path) ~= "string" then - error("fs_base_name: path must be a string", 2) - end - return __fs_base_name(path) -end - -function fs_extension(path) - if type(path) ~= "string" then - error("fs_extension: path must be a string", 2) - end - return __fs_extension(path) -end - --- Utility Functions -function fs_read_json(path) - local content = fs_read(path) - if not content then - return nil, "Could not read file" - end - - local ok, result = pcall(json.decode, content) - if not ok then - return nil, "Invalid JSON: " .. tostring(result) - end - - return result -end - -function fs_write_json(path, data, pretty) - if type(data) ~= "table" then - error("fs_write_json: data must be a table", 2) - end - - local content - if pretty then - content = json.pretty_print(data) - else - content = json.encode(data) - end - - return fs_write(path, content) -end diff --git a/runner/lua/http.lua b/runner/lua/http.lua deleted file mode 100644 index 0ca23b5..0000000 --- a/runner/lua/http.lua +++ /dev/null @@ -1,72 +0,0 @@ --- http.lua - -function http_set_status(code) - __response.status = code -end - -function http_set_header(name, value) - __response.headers = __response.headers or {} - __response.headers[name] = value -end - -function http_set_content_type(ct) - __response.headers = __response.headers or {} - __response.headers["Content-Type"] = ct -end - -function http_set_metadata(key, value) - __response.metadata = __response.metadata or {} - __response.metadata[key] = value -end - -function http_redirect(url, status) - __response.status = status or 302 - __response.headers = __response.headers or {} - __response.headers["Location"] = url - coroutine.yield("__EXIT__") -end - -function send_html(content) - http_set_content_type("text/html") - return content -end - -function send_json(content) - http_set_content_type("application/json") - return content -end - -function send_text(content) - http_set_content_type("text/plain") - return content -end - -function send_xml(content) - http_set_content_type("application/xml") - return content -end - -function send_javascript(content) - http_set_content_type("application/javascript") - return content -end - -function send_css(content) - http_set_content_type("text/css") - return content -end - -function send_svg(content) - http_set_content_type("image/svg+xml") - return content -end - -function send_csv(content) - http_set_content_type("text/csv") - return content -end - -function send_binary(content, mime_type) - http_set_content_type(mime_type or "application/octet-stream") - return content -end diff --git a/runner/lua/json.lua b/runner/lua/json.lua deleted file mode 100644 index 7ea35ba..0000000 --- a/runner/lua/json.lua +++ /dev/null @@ -1,421 +0,0 @@ --- json.lua - -local escape_chars = { - ['"'] = '\\"', ['\\'] = '\\\\', - ['\n'] = '\\n', ['\r'] = '\\r', ['\t'] = '\\t' -} - -function json_go_encode(value) - return __json_marshal(value) -end - -function json_go_decode(str) - if type(str) ~= "string" then - error("json_decode: expected string, got " .. type(str), 2) - end - return __json_unmarshal(str) -end - -function json_encode(data) - local t = type(data) - - if t == "nil" then return "null" end - if t == "boolean" then return data and "true" or "false" end - if t == "number" then return tostring(data) end - - if t == "string" then - return '"' .. data:gsub('[\\"\n\r\t]', escape_chars) .. '"' - end - - if t == "table" then - local isArray = true - local count = 0 - - -- Check if it's an array in one pass - for k, _ in pairs(data) do - count = count + 1 - if type(k) ~= "number" or k ~= count or k < 1 then - isArray = false - break - end - end - - if isArray then - local result = {} - for i = 1, count do - result[i] = json_encode(data[i]) - end - return "[" .. table.concat(result, ",") .. "]" - else - local result = {} - local index = 1 - for k, v in pairs(data) do - if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then - result[index] = json_encode(k) .. ":" .. json_encode(v) - index = index + 1 - end - end - return "{" .. table.concat(result, ",") .. "}" - end - end - - return "null" -- Unsupported type -end - -function json_decode(data) - local pos = 1 - local len = #data - - -- Pre-compute byte values - local b_space = string.byte(' ') - local b_tab = string.byte('\t') - local b_cr = string.byte('\r') - local b_lf = string.byte('\n') - local b_quote = string.byte('"') - local b_backslash = string.byte('\\') - local b_slash = string.byte('/') - local b_lcurly = string.byte('{') - local b_rcurly = string.byte('}') - local b_lbracket = string.byte('[') - local b_rbracket = string.byte(']') - local b_colon = string.byte(':') - local b_comma = string.byte(',') - local b_0 = string.byte('0') - local b_9 = string.byte('9') - local b_minus = string.byte('-') - local b_plus = string.byte('+') - local b_dot = string.byte('.') - local b_e = string.byte('e') - local b_E = string.byte('E') - - -- Skip whitespace more efficiently - local function skip() - local b - while pos <= len do - b = data:byte(pos) - if b > b_space or (b ~= b_space and b ~= b_tab and b ~= b_cr and b ~= b_lf) then - break - end - pos = pos + 1 - end - end - - -- Forward declarations - local parse_value, parse_string, parse_number, parse_object, parse_array - - -- Parse a string more efficiently - parse_string = function() - pos = pos + 1 -- Skip opening quote - - if pos > len then - error("Unterminated string") - end - - -- Use a table to build the string - local result = {} - local result_pos = 1 - local start = pos - local c, b - - while pos <= len do - b = data:byte(pos) - - if b == b_backslash then - -- Add the chunk before the escape character - if pos > start then - result[result_pos] = data:sub(start, pos - 1) - result_pos = result_pos + 1 - end - - pos = pos + 1 - if pos > len then - error("Unterminated string escape") - end - - c = data:byte(pos) - if c == b_quote then - result[result_pos] = '"' - elseif c == b_backslash then - result[result_pos] = '\\' - elseif c == b_slash then - result[result_pos] = '/' - elseif c == string.byte('b') then - result[result_pos] = '\b' - elseif c == string.byte('f') then - result[result_pos] = '\f' - elseif c == string.byte('n') then - result[result_pos] = '\n' - elseif c == string.byte('r') then - result[result_pos] = '\r' - elseif c == string.byte('t') then - result[result_pos] = '\t' - else - result[result_pos] = data:sub(pos, pos) - end - - result_pos = result_pos + 1 - pos = pos + 1 - start = pos - elseif b == b_quote then - -- Add the final chunk - if pos > start then - result[result_pos] = data:sub(start, pos - 1) - result_pos = result_pos + 1 - end - - pos = pos + 1 - return table.concat(result) - else - pos = pos + 1 - end - end - - error("Unterminated string") - end - - -- Parse a number more efficiently - parse_number = function() - local start = pos - local b = data:byte(pos) - - -- Skip any sign - if b == b_minus then - pos = pos + 1 - if pos > len then - error("Malformed number") - end - b = data:byte(pos) - end - - -- Integer part - if b < b_0 or b > b_9 then - error("Malformed number") - end - - repeat - pos = pos + 1 - if pos > len then break end - b = data:byte(pos) - until b < b_0 or b > b_9 - - -- Fractional part - if pos <= len and b == b_dot then - pos = pos + 1 - if pos > len or data:byte(pos) < b_0 or data:byte(pos) > b_9 then - error("Malformed number") - end - - repeat - pos = pos + 1 - if pos > len then break end - b = data:byte(pos) - until b < b_0 or b > b_9 - end - - -- Exponent - if pos <= len and (b == b_e or b == b_E) then - pos = pos + 1 - if pos > len then - error("Malformed number") - end - - b = data:byte(pos) - if b == b_plus or b == b_minus then - pos = pos + 1 - if pos > len then - error("Malformed number") - end - b = data:byte(pos) - end - - if b < b_0 or b > b_9 then - error("Malformed number") - end - - repeat - pos = pos + 1 - if pos > len then break end - b = data:byte(pos) - until b < b_0 or b > b_9 - end - - return tonumber(data:sub(start, pos - 1)) - end - - -- Parse an object more efficiently - parse_object = function() - pos = pos + 1 -- Skip opening brace - local obj = {} - - skip() - if pos <= len and data:byte(pos) == b_rcurly then - pos = pos + 1 - return obj - end - - while pos <= len do - skip() - - if data:byte(pos) ~= b_quote then - error("Expected string key") - end - - local key = parse_string() - skip() - - if data:byte(pos) ~= b_colon then - error("Expected colon") - end - pos = pos + 1 - - obj[key] = parse_value() - skip() - - local b = data:byte(pos) - if b == b_rcurly then - pos = pos + 1 - return obj - end - - if b ~= b_comma then - error("Expected comma or closing brace") - end - pos = pos + 1 - end - - error("Unterminated object") - end - - -- Parse an array more efficiently - parse_array = function() - pos = pos + 1 -- Skip opening bracket - local arr = {} - local index = 1 - - skip() - if pos <= len and data:byte(pos) == b_rbracket then - pos = pos + 1 - return arr - end - - while pos <= len do - arr[index] = parse_value() - index = index + 1 - - skip() - - local b = data:byte(pos) - if b == b_rbracket then - pos = pos + 1 - return arr - end - - if b ~= b_comma then - error("Expected comma or closing bracket") - end - pos = pos + 1 - end - - error("Unterminated array") - end - - -- Parse a value more efficiently - parse_value = function() - skip() - - if pos > len then - error("Unexpected end of input") - end - - local b = data:byte(pos) - - if b == b_quote then - return parse_string() - elseif b == b_lcurly then - return parse_object() - elseif b == b_lbracket then - return parse_array() - elseif b == string.byte('n') and pos + 3 <= len and data:sub(pos, pos + 3) == "null" then - pos = pos + 4 - return nil - elseif b == string.byte('t') and pos + 3 <= len and data:sub(pos, pos + 3) == "true" then - pos = pos + 4 - return true - elseif b == string.byte('f') and pos + 4 <= len and data:sub(pos, pos + 4) == "false" then - pos = pos + 5 - return false - elseif b == b_minus or (b >= b_0 and b <= b_9) then - return parse_number() - else - error("Unexpected character: " .. string.char(b)) - end - end - - skip() - local result = parse_value() - skip() - - if pos <= len then - error("Unexpected trailing characters") - end - - return result -end - -function json_is_valid(str) - if type(str) ~= "string" then return false end - local status, _ = pcall(json_decode, str) - return status -end - -function json_pretty_print(value) - if type(value) == "string" then - value = json_decode(value) - end - - local function stringify(val, indent, visited) - visited = visited or {} - indent = indent or 0 - local spaces = string.rep(" ", indent) - - if type(val) == "table" then - if visited[val] then return "{...}" end - visited[val] = true - - local isArray = true - local i = 1 - for k in pairs(val) do - if type(k) ~= "number" or k ~= i then - isArray = false - break - end - i = i + 1 - end - - local result = isArray and "[\n" or "{\n" - local first = true - - if isArray then - for i, v in ipairs(val) do - if not first then result = result .. ",\n" end - first = false - result = result .. spaces .. " " .. stringify(v, indent + 1, visited) - end - else - for k, v in pairs(val) do - if not first then result = result .. ",\n" end - first = false - result = result .. spaces .. " \"" .. tostring(k) .. "\": " .. stringify(v, indent + 1, visited) - end - end - - return result .. "\n" .. spaces .. (isArray and "]" or "}") - elseif type(val) == "string" then - return "\"" .. val:gsub('\\', '\\\\'):gsub('"', '\\"'):gsub('\n', '\\n') .. "\"" - else - return tostring(val) - end - end - - return stringify(value) -end diff --git a/runner/lua/math.lua b/runner/lua/math.lua deleted file mode 100644 index c48dca9..0000000 --- a/runner/lua/math.lua +++ /dev/null @@ -1,800 +0,0 @@ --- math.lua - -local math_ext = {} - --- Import standard math functions -for name, func in pairs(_G.math) do - math_ext[name] = func -end - --- ====================================================================== --- CONSTANTS (higher precision) --- ====================================================================== - -math_ext.pi = 3.14159265358979323846 -math_ext.tau = 6.28318530717958647693 -- 2*pi -math_ext.e = 2.71828182845904523536 -math_ext.phi = 1.61803398874989484820 -- Golden ratio -math_ext.sqrt2 = 1.41421356237309504880 -math_ext.sqrt3 = 1.73205080756887729353 -math_ext.ln2 = 0.69314718055994530942 -math_ext.ln10 = 2.30258509299404568402 -math_ext.infinity = 1/0 -math_ext.nan = 0/0 - --- ====================================================================== --- EXTENDED FUNCTIONS --- ====================================================================== - --- Cube root (handles negative numbers correctly) -function math_ext.cbrt(x) - return x < 0 and -(-x)^(1/3) or x^(1/3) -end - --- Hypotenuse of right-angled triangle -function math_ext.hypot(x, y) - return math.sqrt(x * x + y * y) -end - --- Check if value is NaN -function math_ext.isnan(x) - return x ~= x -end - --- Check if value is finite -function math_ext.isfinite(x) - return x > -math_ext.infinity and x < math_ext.infinity -end - --- Sign function (-1, 0, 1) -function math_ext.sign(x) - return x > 0 and 1 or (x < 0 and -1 or 0) -end - --- Clamp value between min and max -function math_ext.clamp(x, min, max) - return x < min and min or (x > max and max or x) -end - --- Linear interpolation -function math_ext.lerp(a, b, t) - return a + (b - a) * t -end - --- Smooth step interpolation -function math_ext.smoothstep(a, b, t) - t = math_ext.clamp((t - a) / (b - a), 0, 1) - return t * t * (3 - 2 * t) -end - --- Map value from one range to another -function math_ext.map(x, in_min, in_max, out_min, out_max) - return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min -end - --- Round to nearest integer -function math_ext.round(x) - return x >= 0 and math.floor(x + 0.5) or math.ceil(x - 0.5) -end - --- Round to specified decimal places -function math_ext.roundto(x, decimals) - local mult = 10 ^ (decimals or 0) - return math.floor(x * mult + 0.5) / mult -end - --- Normalize angle to [-π, π] -function math_ext.normalize_angle(angle) - return angle - 2 * math_ext.pi * math.floor((angle + math_ext.pi) / (2 * math_ext.pi)) -end - --- Distance between points -function math_ext.distance(x1, y1, x2, y2) - local dx, dy = x2 - x1, y2 - y1 - return math.sqrt(dx * dx + dy * dy) -end - --- ====================================================================== --- RANDOM NUMBER FUNCTIONS --- ====================================================================== - --- Random float in range [min, max) -function math_ext.randomf(min, max) - if not min and not max then - return math.random() - elseif not max then - max = min - min = 0 - end - return min + math.random() * (max - min) -end - --- Random integer in range [min, max] -function math_ext.randint(min, max) - if not max then - max = min - min = 1 - end - return math.floor(math.random() * (max - min + 1) + min) -end - --- Random boolean with probability p (default 0.5) -function math_ext.randboolean(p) - p = p or 0.5 - return math.random() < p -end - --- ====================================================================== --- STATISTICS FUNCTIONS --- ====================================================================== - --- Sum of values -function math_ext.sum(t) - if type(t) ~= "table" then return 0 end - - local sum = 0 - for i=1, #t do - if type(t[i]) == "number" then - sum = sum + t[i] - end - end - return sum -end - --- Mean (average) of values -function math_ext.mean(t) - if type(t) ~= "table" or #t == 0 then return 0 end - - local sum = 0 - local count = 0 - for i=1, #t do - if type(t[i]) == "number" then - sum = sum + t[i] - count = count + 1 - end - end - return count > 0 and sum / count or 0 -end - --- Median of values -function math_ext.median(t) - if type(t) ~= "table" or #t == 0 then return 0 end - - local nums = {} - local count = 0 - for i=1, #t do - if type(t[i]) == "number" then - count = count + 1 - nums[count] = t[i] - end - end - - if count == 0 then return 0 end - - table.sort(nums) - - if count % 2 == 0 then - return (nums[count/2] + nums[count/2 + 1]) / 2 - else - return nums[math.ceil(count/2)] - end -end - --- Variance of values -function math_ext.variance(t) - if type(t) ~= "table" then return 0 end - - local count = 0 - local m = math_ext.mean(t) - local sum = 0 - - for i=1, #t do - if type(t[i]) == "number" then - local dev = t[i] - m - sum = sum + dev * dev - count = count + 1 - end - end - - return count > 1 and sum / count or 0 -end - --- Standard deviation -function math_ext.stdev(t) - return math.sqrt(math_ext.variance(t)) -end - --- Population variance -function math_ext.pvariance(t) - if type(t) ~= "table" then return 0 end - - local count = 0 - local m = math_ext.mean(t) - local sum = 0 - - for i=1, #t do - if type(t[i]) == "number" then - local dev = t[i] - m - sum = sum + dev * dev - count = count + 1 - end - end - - return count > 0 and sum / count or 0 -end - --- Population standard deviation -function math_ext.pstdev(t) - return math.sqrt(math_ext.pvariance(t)) -end - --- Mode (most common value) -function math_ext.mode(t) - if type(t) ~= "table" or #t == 0 then return nil end - - local counts = {} - local most_frequent = nil - local max_count = 0 - - for i=1, #t do - local v = t[i] - counts[v] = (counts[v] or 0) + 1 - if counts[v] > max_count then - max_count = counts[v] - most_frequent = v - end - end - - return most_frequent -end - --- Min and max simultaneously (faster than calling both separately) -function math_ext.minmax(t) - if type(t) ~= "table" or #t == 0 then return nil, nil end - - local min, max - for i=1, #t do - if type(t[i]) == "number" then - min = t[i] - max = t[i] - break - end - end - - if min == nil then return nil, nil end - - for i=1, #t do - if type(t[i]) == "number" then - if t[i] < min then min = t[i] end - if t[i] > max then max = t[i] end - end - end - - return min, max -end - --- ====================================================================== --- VECTOR OPERATIONS (2D/3D vectors) --- ====================================================================== - --- 2D Vector operations -math_ext.vec2 = { - new = function(x, y) - return {x = x or 0, y = y or 0} - end, - - copy = function(v) - return {x = v.x, y = v.y} - end, - - add = function(a, b) - return {x = a.x + b.x, y = a.y + b.y} - end, - - sub = function(a, b) - return {x = a.x - b.x, y = a.y - b.y} - end, - - mul = function(a, b) - if type(b) == "number" then - return {x = a.x * b, y = a.y * b} - end - return {x = a.x * b.x, y = a.y * b.y} - end, - - div = function(a, b) - if type(b) == "number" then - local inv = 1 / b - return {x = a.x * inv, y = a.y * inv} - end - return {x = a.x / b.x, y = a.y / b.y} - end, - - dot = function(a, b) - return a.x * b.x + a.y * b.y - end, - - length = function(v) - return math.sqrt(v.x * v.x + v.y * v.y) - end, - - length_squared = function(v) - return v.x * v.x + v.y * v.y - end, - - distance = function(a, b) - local dx, dy = b.x - a.x, b.y - a.y - return math.sqrt(dx * dx + dy * dy) - end, - - distance_squared = function(a, b) - local dx, dy = b.x - a.x, b.y - a.y - return dx * dx + dy * dy - end, - - normalize = function(v) - local len = math.sqrt(v.x * v.x + v.y * v.y) - if len > 1e-10 then - local inv_len = 1 / len - return {x = v.x * inv_len, y = v.y * inv_len} - end - return {x = 0, y = 0} - end, - - rotate = function(v, angle) - local c, s = math.cos(angle), math.sin(angle) - return { - x = v.x * c - v.y * s, - y = v.x * s + v.y * c - } - end, - - angle = function(v) - return math.atan2(v.y, v.x) - end, - - lerp = function(a, b, t) - t = math_ext.clamp(t, 0, 1) - return { - x = a.x + (b.x - a.x) * t, - y = a.y + (b.y - a.y) * t - } - end, - - reflect = function(v, normal) - local dot = v.x * normal.x + v.y * normal.y - return { - x = v.x - 2 * dot * normal.x, - y = v.y - 2 * dot * normal.y - } - end -} - --- 3D Vector operations -math_ext.vec3 = { - new = function(x, y, z) - return {x = x or 0, y = y or 0, z = z or 0} - end, - - copy = function(v) - return {x = v.x, y = v.y, z = v.z} - end, - - add = function(a, b) - return {x = a.x + b.x, y = a.y + b.y, z = a.z + b.z} - end, - - sub = function(a, b) - return {x = a.x - b.x, y = a.y - b.y, z = a.z - b.z} - end, - - mul = function(a, b) - if type(b) == "number" then - return {x = a.x * b, y = a.y * b, z = a.z * b} - end - return {x = a.x * b.x, y = a.y * b.y, z = a.z * b.z} - end, - - div = function(a, b) - if type(b) == "number" then - local inv = 1 / b - return {x = a.x * inv, y = a.y * inv, z = a.z * inv} - end - return {x = a.x / b.x, y = a.y / b.y, z = a.z / b.z} - end, - - dot = function(a, b) - return a.x * b.x + a.y * b.y + a.z * b.z - end, - - cross = function(a, b) - return { - x = a.y * b.z - a.z * b.y, - y = a.z * b.x - a.x * b.z, - z = a.x * b.y - a.y * b.x - } - end, - - length = function(v) - return math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z) - end, - - length_squared = function(v) - return v.x * v.x + v.y * v.y + v.z * v.z - end, - - distance = function(a, b) - local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z - return math.sqrt(dx * dx + dy * dy + dz * dz) - end, - - distance_squared = function(a, b) - local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z - return dx * dx + dy * dy + dz * dz - end, - - normalize = function(v) - local len = math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z) - if len > 1e-10 then - local inv_len = 1 / len - return {x = v.x * inv_len, y = v.y * inv_len, z = v.z * inv_len} - end - return {x = 0, y = 0, z = 0} - end, - - lerp = function(a, b, t) - t = math_ext.clamp(t, 0, 1) - return { - x = a.x + (b.x - a.x) * t, - y = a.y + (b.y - a.y) * t, - z = a.z + (b.z - a.z) * t - } - end, - - reflect = function(v, normal) - local dot = v.x * normal.x + v.y * normal.y + v.z * normal.z - return { - x = v.x - 2 * dot * normal.x, - y = v.y - 2 * dot * normal.y, - z = v.z - 2 * dot * normal.z - } - end -} - --- ====================================================================== --- MATRIX OPERATIONS (2x2 and 3x3 matrices) --- ====================================================================== - -math_ext.mat2 = { - -- Create a new 2x2 matrix - new = function(a, b, c, d) - return { - {a or 1, b or 0}, - {c or 0, d or 1} - } - end, - - -- Create identity matrix - identity = function() - return {{1, 0}, {0, 1}} - end, - - -- Matrix multiplication - mul = function(a, b) - return { - { - a[1][1] * b[1][1] + a[1][2] * b[2][1], - a[1][1] * b[1][2] + a[1][2] * b[2][2] - }, - { - a[2][1] * b[1][1] + a[2][2] * b[2][1], - a[2][1] * b[1][2] + a[2][2] * b[2][2] - } - } - end, - - -- Determinant - det = function(m) - return m[1][1] * m[2][2] - m[1][2] * m[2][1] - end, - - -- Inverse matrix - inverse = function(m) - local det = m[1][1] * m[2][2] - m[1][2] * m[2][1] - if math.abs(det) < 1e-10 then - return nil -- Matrix is not invertible - end - - local inv_det = 1 / det - return { - {m[2][2] * inv_det, -m[1][2] * inv_det}, - {-m[2][1] * inv_det, m[1][1] * inv_det} - } - end, - - -- Rotation matrix - rotation = function(angle) - local cos, sin = math.cos(angle), math.sin(angle) - return { - {cos, -sin}, - {sin, cos} - } - end, - - -- Apply matrix to vector - transform = function(m, v) - return { - x = m[1][1] * v.x + m[1][2] * v.y, - y = m[2][1] * v.x + m[2][2] * v.y - } - end, - - -- Scale matrix - scale = function(sx, sy) - sy = sy or sx - return { - {sx, 0}, - {0, sy} - } - end -} - -math_ext.mat3 = { - -- Create identity matrix 3x3 - identity = function() - return { - {1, 0, 0}, - {0, 1, 0}, - {0, 0, 1} - } - end, - - -- Create a 2D transformation matrix (translation, rotation, scale) - transform = function(x, y, angle, sx, sy) - sx = sx or 1 - sy = sy or sx - local cos, sin = math.cos(angle), math.sin(angle) - return { - {cos * sx, -sin * sy, x}, - {sin * sx, cos * sy, y}, - {0, 0, 1} - } - end, - - -- Matrix multiplication - mul = function(a, b) - local result = { - {0, 0, 0}, - {0, 0, 0}, - {0, 0, 0} - } - - for i = 1, 3 do - for j = 1, 3 do - for k = 1, 3 do - result[i][j] = result[i][j] + a[i][k] * b[k][j] - end - end - end - - return result - end, - - -- Apply matrix to point (homogeneous coordinates) - transform_point = function(m, v) - local x = m[1][1] * v.x + m[1][2] * v.y + m[1][3] - local y = m[2][1] * v.x + m[2][2] * v.y + m[2][3] - local w = m[3][1] * v.x + m[3][2] * v.y + m[3][3] - - if math.abs(w) < 1e-10 then - return {x = 0, y = 0} - end - - return {x = x / w, y = y / w} - end, - - -- Translation matrix - translation = function(x, y) - return { - {1, 0, x}, - {0, 1, y}, - {0, 0, 1} - } - end, - - -- Rotation matrix - rotation = function(angle) - local cos, sin = math.cos(angle), math.sin(angle) - return { - {cos, -sin, 0}, - {sin, cos, 0}, - {0, 0, 1} - } - end, - - -- Scale matrix - scale = function(sx, sy) - sy = sy or sx - return { - {sx, 0, 0}, - {0, sy, 0}, - {0, 0, 1} - } - end, - - -- Determinant - det = function(m) - return m[1][1] * (m[2][2] * m[3][3] - m[2][3] * m[3][2]) - - m[1][2] * (m[2][1] * m[3][3] - m[2][3] * m[3][1]) + - m[1][3] * (m[2][1] * m[3][2] - m[2][2] * m[3][1]) - end -} - --- ====================================================================== --- GEOMETRY FUNCTIONS --- ====================================================================== - -math_ext.geometry = { - -- Distance from point to line - point_line_distance = function(px, py, x1, y1, x2, y2) - local dx, dy = x2 - x1, y2 - y1 - local len_sq = dx * dx + dy * dy - - if len_sq < 1e-10 then - return math_ext.distance(px, py, x1, y1) - end - - local t = ((px - x1) * dx + (py - y1) * dy) / len_sq - t = math_ext.clamp(t, 0, 1) - - local nearestX = x1 + t * dx - local nearestY = y1 + t * dy - - return math_ext.distance(px, py, nearestX, nearestY) - end, - - -- Check if point is inside polygon - point_in_polygon = function(px, py, vertices) - local inside = false - local n = #vertices / 2 - - for i = 1, n do - local x1, y1 = vertices[i*2-1], vertices[i*2] - local x2, y2 - - if i == n then - x2, y2 = vertices[1], vertices[2] - else - x2, y2 = vertices[i*2+1], vertices[i*2+2] - end - - if ((y1 > py) ~= (y2 > py)) and - (px < (x2 - x1) * (py - y1) / (y2 - y1) + x1) then - inside = not inside - end - end - - return inside - end, - - -- Area of a triangle - triangle_area = function(x1, y1, x2, y2, x3, y3) - return math.abs((x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)) / 2) - end, - - -- Check if point is inside triangle - point_in_triangle = function(px, py, x1, y1, x2, y2, x3, y3) - local area = math_ext.geometry.triangle_area(x1, y1, x2, y2, x3, y3) - local area1 = math_ext.geometry.triangle_area(px, py, x2, y2, x3, y3) - local area2 = math_ext.geometry.triangle_area(x1, y1, px, py, x3, y3) - local area3 = math_ext.geometry.triangle_area(x1, y1, x2, y2, px, py) - - return math.abs(area - (area1 + area2 + area3)) < 1e-10 - end, - - -- Check if two line segments intersect - line_intersect = function(x1, y1, x2, y2, x3, y3, x4, y4) - local d = (y4 - y3) * (x2 - x1) - (x4 - x3) * (y2 - y1) - - if math.abs(d) < 1e-10 then - return false, nil, nil -- Lines are parallel - end - - local ua = ((x4 - x3) * (y1 - y3) - (y4 - y3) * (x1 - x3)) / d - local ub = ((x2 - x1) * (y1 - y3) - (y2 - y1) * (x1 - x3)) / d - - if ua >= 0 and ua <= 1 and ub >= 0 and ub <= 1 then - local x = x1 + ua * (x2 - x1) - local y = y1 + ua * (y2 - y1) - return true, x, y - end - - return false, nil, nil - end, - - -- Closest point on line segment to point - closest_point_on_segment = function(px, py, x1, y1, x2, y2) - local dx, dy = x2 - x1, y2 - y1 - local len_sq = dx * dx + dy * dy - - if len_sq < 1e-10 then - return x1, y1 - end - - local t = ((px - x1) * dx + (py - y1) * dy) / len_sq - t = math_ext.clamp(t, 0, 1) - - return x1 + t * dx, y1 + t * dy - end -} - --- ====================================================================== --- INTERPOLATION FUNCTIONS --- ====================================================================== - -math_ext.interpolation = { - -- Cubic Bezier interpolation - bezier = function(t, p0, p1, p2, p3) - t = math_ext.clamp(t, 0, 1) - local t2 = t * t - local t3 = t2 * t - local mt = 1 - t - local mt2 = mt * mt - local mt3 = mt2 * mt - - return p0 * mt3 + 3 * p1 * mt2 * t + 3 * p2 * mt * t2 + p3 * t3 - end, - - -- Catmull-Rom spline interpolation - catmull_rom = function(t, p0, p1, p2, p3) - t = math_ext.clamp(t, 0, 1) - local t2 = t * t - local t3 = t2 * t - - return 0.5 * ( - (2 * p1) + - (-p0 + p2) * t + - (2 * p0 - 5 * p1 + 4 * p2 - p3) * t2 + - (-p0 + 3 * p1 - 3 * p2 + p3) * t3 - ) - end, - - -- Hermite interpolation - hermite = function(t, p0, p1, m0, m1) - t = math_ext.clamp(t, 0, 1) - local t2 = t * t - local t3 = t2 * t - local h00 = 2 * t3 - 3 * t2 + 1 - local h10 = t3 - 2 * t2 + t - local h01 = -2 * t3 + 3 * t2 - local h11 = t3 - t2 - - return h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1 - end, - - -- Quadratic Bezier interpolation - quadratic_bezier = function(t, p0, p1, p2) - t = math_ext.clamp(t, 0, 1) - local mt = 1 - t - return mt * mt * p0 + 2 * mt * t * p1 + t * t * p2 - end, - - -- Step interpolation - step = function(t, edge, x) - return t < edge and 0 or x - end, - - -- Smoothstep interpolation - smoothstep = function(edge0, edge1, x) - local t = math_ext.clamp((x - edge0) / (edge1 - edge0), 0, 1) - return t * t * (3 - 2 * t) - end, - - -- Smootherstep interpolation (Ken Perlin) - smootherstep = function(edge0, edge1, x) - local t = math_ext.clamp((x - edge0) / (edge1 - edge0), 0, 1) - return t * t * t * (t * (t * 6 - 15) + 10) - end -} - -return math_ext diff --git a/runner/lua/render.lua b/runner/lua/render.lua deleted file mode 100644 index d38d367..0000000 --- a/runner/lua/render.lua +++ /dev/null @@ -1,190 +0,0 @@ --- render.lua - --- Template processing with code execution -function render(template_str, env) - local function is_control_structure(code) - -- Check if code is a control structure that doesn't produce output - local trimmed = code:match("^%s*(.-)%s*$") - return trimmed == "else" or - trimmed == "end" or - trimmed:match("^if%s") or - trimmed:match("^elseif%s") or - trimmed:match("^for%s") or - trimmed:match("^while%s") or - trimmed:match("^repeat%s*$") or - trimmed:match("^until%s") or - trimmed:match("^do%s*$") or - trimmed:match("^local%s") or - trimmed:match("^function%s") or - trimmed:match(".*=%s*function%s*%(") or - trimmed:match(".*then%s*$") or - trimmed:match(".*do%s*$") - end - - local pos, chunks = 1, {} - while pos <= #template_str do - local unescaped_start = template_str:find("{{{", pos, true) - local escaped_start = template_str:find("{{", pos, true) - - local start, tag_type, open_len - if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then - start, tag_type, open_len = unescaped_start, "-", 3 - elseif escaped_start then - start, tag_type, open_len = escaped_start, "=", 2 - else - table.insert(chunks, template_str:sub(pos)) - break - end - - if start > pos then - table.insert(chunks, template_str:sub(pos, start-1)) - end - - pos = start + open_len - local close_tag = tag_type == "-" and "}}}" or "}}" - local close_start, close_stop = template_str:find(close_tag, pos, true) - if not close_start then - error("Failed to find closing tag at position " .. pos) - end - - local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$") - local is_control = is_control_structure(code) - - table.insert(chunks, {tag_type, code, pos, is_control}) - pos = close_stop + 1 - end - - local buffer = {"local _tostring, _escape, _b, _b_i = ...\n"} - for _, chunk in ipairs(chunks) do - local t = type(chunk) - if t == "string" then - table.insert(buffer, "_b_i = _b_i + 1\n") - table.insert(buffer, "_b[_b_i] = " .. string.format("%q", chunk) .. "\n") - else - local tag_type, code, pos, is_control = chunk[1], chunk[2], chunk[3], chunk[4] - - if is_control then - -- Control structure - just insert as raw Lua code - table.insert(buffer, "--[[" .. pos .. "]] " .. code .. "\n") - elseif tag_type == "=" then - -- Simple variable check - if code:match("^[%w_]+$") then - table.insert(buffer, "_b_i = _b_i + 1\n") - table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n") - else - -- Expression output with escaping - table.insert(buffer, "_b_i = _b_i + 1\n") - table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n") - end - elseif tag_type == "-" then - -- Unescaped output - table.insert(buffer, "_b_i = _b_i + 1\n") - table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _tostring(" .. code .. ")\n") - end - end - end - table.insert(buffer, "return _b") - - local generated_code = table.concat(buffer) - - -- DEBUG: Uncomment to see generated code - -- print("Generated Lua code:") - -- print(generated_code) - -- print("---") - - local fn, err = loadstring(generated_code) - if not fn then - print("Generated code that failed to compile:") - print(generated_code) - error(err) - end - - env = env or {} - local runtime_env = setmetatable({}, {__index = function(_, k) return env[k] or _G[k] end}) - setfenv(fn, runtime_env) - - local output_buffer = {} - fn(tostring, html_special_chars, output_buffer, 0) - return table.concat(output_buffer) -end - --- Named placeholder processing -function parse(template_str, env) - local pos, output = 1, {} - env = env or {} - - while pos <= #template_str do - local unescaped_start, unescaped_end, unescaped_name = template_str:find("{{{%s*([%w_]+)%s*}}}", pos) - local escaped_start, escaped_end, escaped_name = template_str:find("{{%s*([%w_]+)%s*}}", pos) - - local next_pos, placeholder_end, name, escaped - if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then - next_pos, placeholder_end, name, escaped = unescaped_start, unescaped_end, unescaped_name, false - elseif escaped_start then - next_pos, placeholder_end, name, escaped = escaped_start, escaped_end, escaped_name, true - else - local text = template_str:sub(pos) - if text and #text > 0 then - table.insert(output, text) - end - break - end - - local text = template_str:sub(pos, next_pos - 1) - if text and #text > 0 then - table.insert(output, text) - end - - local value = env[name] - local str = tostring(value or "") - if escaped then - str = html_special_chars(str) - end - table.insert(output, str) - - pos = placeholder_end + 1 - end - - return table.concat(output) -end - --- Indexed placeholder processing -function iparse(template_str, values) - local pos, output, value_index = 1, {}, 1 - values = values or {} - - while pos <= #template_str do - local unescaped_start, unescaped_end = template_str:find("{{{}}}", pos, true) - local escaped_start, escaped_end = template_str:find("{{}}", pos, true) - - local next_pos, placeholder_end, escaped - if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then - next_pos, placeholder_end, escaped = unescaped_start, unescaped_end, false - elseif escaped_start then - next_pos, placeholder_end, escaped = escaped_start, escaped_end, true - else - local text = template_str:sub(pos) - if text and #text > 0 then - table.insert(output, text) - end - break - end - - local text = template_str:sub(pos, next_pos - 1) - if text and #text > 0 then - table.insert(output, text) - end - - local value = values[value_index] - local str = tostring(value or "") - if escaped then - str = html_special_chars(str) - end - table.insert(output, str) - - pos = placeholder_end + 1 - value_index = value_index + 1 - end - - return table.concat(output) -end diff --git a/runner/lua/sandbox.lua b/runner/lua/sandbox.lua deleted file mode 100644 index 5e9d3d2..0000000 --- a/runner/lua/sandbox.lua +++ /dev/null @@ -1,36 +0,0 @@ --- sandbox.lua - -function __execute(script_func, ctx, response) - -- Store context and response globally for function access - __ctx = ctx - __response = response - _G.ctx = ctx - - -- Create a coroutine for script execution to handle early exits - local co = coroutine.create(function() - return script_func() - end) - - local ok, result = coroutine.resume(co) - - -- Clean up - __ctx = nil - __response = nil - - if not ok then - -- Real error during script execution - error(result, 0) - end - - -- Check if exit was requested - if result == "__EXIT__" then - return {nil, response} - end - - return {result, response} -end - --- Exit sentinel using coroutine yield instead of error -function exit() - coroutine.yield("__EXIT__") -end diff --git a/runner/lua/session.lua b/runner/lua/session.lua deleted file mode 100644 index 837d689..0000000 --- a/runner/lua/session.lua +++ /dev/null @@ -1,179 +0,0 @@ --- session.lua - -function session_set(key, value) - __response.session = __response.session or {} - __response.session[key] = value - if __ctx.session and __ctx.session.data then - __ctx.session.data[key] = value - end -end - -function session_get(key) - return __ctx.session and __ctx.session.data and __ctx.session.data[key] -end - -function session_id() - return __ctx.session and __ctx.session.id -end - -function session_get_all() - if __ctx.session and __ctx.session.data then - local copy = {} - for k, v in pairs(__ctx.session.data) do - copy[k] = v - end - return copy - end - return {} -end - -function session_delete(key) - __response.session = __response.session or {} - __response.session[key] = "__DELETE__" - if __ctx.session and __ctx.session.data then - __ctx.session.data[key] = nil - end -end - -function session_clear() - __response.session = {__clear_all = true} - if __ctx.session and __ctx.session.data then - for k in pairs(__ctx.session.data) do - __ctx.session.data[k] = nil - end - end -end - -function session_flash(key, value) - __response.flash = __response.flash or {} - __response.flash[key] = value -end - -function session_get_flash(key) - -- Check current flash data first - if __response.flash and __response.flash[key] ~= nil then - return __response.flash[key] - end - - -- Check session flash data - if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then - return __ctx.session.flash[key] - end - - return nil -end - -function session_has_flash(key) - -- Check current flash - if __response.flash and __response.flash[key] ~= nil then - return true - end - - -- Check session flash - if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then - return true - end - - return false -end - -function session_get_all_flash() - local flash = {} - - -- Add session flash data first - if __ctx.session and __ctx.session.flash then - for k, v in pairs(__ctx.session.flash) do - flash[k] = v - end - end - - -- Add current response flash (overwrites session flash if same key) - if __response.flash then - for k, v in pairs(__response.flash) do - flash[k] = v - end - end - - return flash -end - -function session_flash_now(key, value) - -- Flash for current request only (not persisted) - _G._current_flash = _G._current_flash or {} - _G._current_flash[key] = value -end - -function session_get_flash_now(key) - return _G._current_flash and _G._current_flash[key] -end - --- ====================================================================== --- FLASH HELPER FUNCTIONS --- ====================================================================== - -function flash_success(message) - session_flash("success", message) -end - -function flash_error(message) - session_flash("error", message) -end - -function flash_warning(message) - session_flash("warning", message) -end - -function flash_info(message) - session_flash("info", message) -end - -function flash_message(type, message) - session_flash(type, message) -end - --- Get flash messages by type -function get_flash_success() - return session_get_flash("success") -end - -function get_flash_error() - return session_get_flash("error") -end - -function get_flash_warning() - return session_get_flash("warning") -end - -function get_flash_info() - return session_get_flash("info") -end - --- Check if flash messages exist -function has_flash_success() - return session_has_flash("success") -end - -function has_flash_error() - return session_has_flash("error") -end - -function has_flash_warning() - return session_has_flash("warning") -end - -function has_flash_info() - return session_has_flash("info") -end - -function redirect_with_flash(url, type, message, status) - session_flash(type or "info", message) - http_redirect(url, status) -end - -function redirect_with_success(url, message, status) - redirect_with_flash(url, "success", message, status) -end - -function redirect_with_error(url, message, status) - redirect_with_flash(url, "error", message, status) -end diff --git a/runner/lua/sqlite.lua b/runner/lua/sqlite.lua deleted file mode 100644 index 7b979b7..0000000 --- a/runner/lua/sqlite.lua +++ /dev/null @@ -1,299 +0,0 @@ --- sqlite.lua - -local function normalize_params(params, ...) - if type(params) == "table" then return params end - local args = {...} - if #args > 0 or params ~= nil then - table.insert(args, 1, params) - return args - end - return nil -end - -local connection_mt = { - __index = { - query = function(self, query, params, ...) - if type(query) ~= "string" then - error("connection:query: query must be a string", 2) - end - - local normalized_params = normalize_params(params, ...) - return __sqlite_query(self.db_name, query, normalized_params, __STATE_INDEX) - end, - - exec = function(self, query, params, ...) - if type(query) ~= "string" then - error("connection:exec: query must be a string", 2) - end - - local normalized_params = normalize_params(params, ...) - return __sqlite_exec(self.db_name, query, normalized_params, __STATE_INDEX) - end, - - get_one = function(self, query, params, ...) - if type(query) ~= "string" then - error("connection:get_one: query must be a string", 2) - end - - local normalized_params = normalize_params(params, ...) - return __sqlite_get_one(self.db_name, query, normalized_params, __STATE_INDEX) - end, - - insert = function(self, table_name, data, columns) - if type(data) ~= "table" then - error("connection:insert: data must be a table", 2) - end - - -- Single object: {col1=val1, col2=val2} - if data[1] == nil and next(data) ~= nil then - local cols = table.keys(data) - local placeholders = table.map(cols, function(_, i) return ":p" .. i end) - local params = {} - for i, col in ipairs(cols) do - params["p" .. i] = data[col] - end - - local query = string.format( - "INSERT INTO %s (%s) VALUES (%s)", - table_name, - table.concat(cols, ", "), - table.concat(placeholders, ", ") - ) - return self:exec(query, params) - end - - -- Array data with columns - if columns and type(columns) == "table" then - if #data > 0 and type(data[1]) == "table" then - -- Multiple rows - local value_groups = {} - local params = {} - local param_idx = 1 - - for _, row in ipairs(data) do - local row_placeholders = {} - for j = 1, #columns do - local param_name = "p" .. param_idx - table.insert(row_placeholders, ":" .. param_name) - params[param_name] = row[j] - param_idx = param_idx + 1 - end - table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")") - end - - local query = string.format( - "INSERT INTO %s (%s) VALUES %s", - table_name, - table.concat(columns, ", "), - table.concat(value_groups, ", ") - ) - return self:exec(query, params) - else - -- Single row array - local placeholders = table.map(columns, function(_, i) return ":p" .. i end) - local params = {} - for i = 1, #columns do - params["p" .. i] = data[i] - end - - local query = string.format( - "INSERT INTO %s (%s) VALUES (%s)", - table_name, - table.concat(columns, ", "), - table.concat(placeholders, ", ") - ) - return self:exec(query, params) - end - end - - -- Array of objects - if #data > 0 and type(data[1]) == "table" and data[1][1] == nil then - local cols = table.keys(data[1]) - local value_groups = {} - local params = {} - local param_idx = 1 - - for _, row in ipairs(data) do - local row_placeholders = {} - for _, col in ipairs(cols) do - local param_name = "p" .. param_idx - table.insert(row_placeholders, ":" .. param_name) - params[param_name] = row[col] - param_idx = param_idx + 1 - end - table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")") - end - - local query = string.format( - "INSERT INTO %s (%s) VALUES %s", - table_name, - table.concat(cols, ", "), - table.concat(value_groups, ", ") - ) - return self:exec(query, params) - end - - error("connection:insert: invalid data format", 2) - end, - - update = function(self, table_name, data, where, where_params, ...) - if type(data) ~= "table" or next(data) == nil then - return 0 - end - - local sets = {} - local params = {} - local param_idx = 1 - - for col, val in pairs(data) do - local param_name = "p" .. param_idx - table.insert(sets, col .. " = :" .. param_name) - params[param_name] = val - param_idx = param_idx + 1 - end - - local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", ")) - - if where then - query = query .. " WHERE " .. where - if where_params then - local normalized = normalize_params(where_params, ...) - if type(normalized) == "table" then - for k, v in pairs(normalized) do - if type(k) == "string" then - params[k] = v - else - params["w" .. param_idx] = v - param_idx = param_idx + 1 - end - end - end - end - end - - return self:exec(query, params) - end, - - create_table = function(self, table_name, ...) - local column_definitions = {} - local index_definitions = {} - - for _, def_string in ipairs({...}) do - if type(def_string) == "string" then - local is_unique = false - local index_def = def_string - - if string.starts_with(def_string, "UNIQUE INDEX:") then - is_unique = true - index_def = string.trim(def_string:sub(14)) - elseif string.starts_with(def_string, "INDEX:") then - index_def = string.trim(def_string:sub(7)) - else - table.insert(column_definitions, def_string) - goto continue - end - - local paren_pos = index_def:find("%(") - if not paren_pos then goto continue end - - local index_name = string.trim(index_def:sub(1, paren_pos - 1)) - local columns_part = index_def:sub(paren_pos + 1):match("^(.-)%)%s*$") - if not columns_part then goto continue end - - local columns = table.map(string.split(columns_part, ","), string.trim) - - if #columns > 0 then - table.insert(index_definitions, { - name = index_name, - columns = columns, - unique = is_unique - }) - end - end - ::continue:: - end - - if #column_definitions == 0 then - error("connection:create_table: no column definitions specified for table " .. table_name, 2) - end - - local statements = {} - - table.insert(statements, string.format( - "CREATE TABLE IF NOT EXISTS %s (%s)", - table_name, - table.concat(column_definitions, ", ") - )) - - for _, idx in ipairs(index_definitions) do - local unique_prefix = idx.unique and "UNIQUE " or "" - table.insert(statements, string.format( - "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)", - unique_prefix, - idx.name, - table_name, - table.concat(idx.columns, ", ") - )) - end - - return self:exec(table.concat(statements, ";\n")) - end, - - delete = function(self, table_name, where, params, ...) - local query = "DELETE FROM " .. table_name - if where then - query = query .. " WHERE " .. where - end - return self:exec(query, normalize_params(params, ...)) - end, - - exists = function(self, table_name, where, params, ...) - if type(table_name) ~= "string" then - error("connection:exists: table_name must be a string", 2) - end - - local query = "SELECT 1 FROM " .. table_name - if where then - query = query .. " WHERE " .. where - end - query = query .. " LIMIT 1" - - local results = self:query(query, normalize_params(params, ...)) - return #results > 0 - end, - - begin = function(self) - return self:exec("BEGIN TRANSACTION") - end, - - commit = function(self) - return self:exec("COMMIT") - end, - - rollback = function(self) - return self:exec("ROLLBACK") - end, - - transaction = function(self, callback) - self:begin() - local success, result = pcall(callback, self) - if success then - self:commit() - return result - else - self:rollback() - error(result, 2) - end - end - } -} - -function sqlite(db_name) - if type(db_name) ~= "string" then - error("sqlite: database name must be a string", 2) - end - - return setmetatable({ - db_name = db_name - }, connection_mt) -end diff --git a/runner/lua/string.lua b/runner/lua/string.lua deleted file mode 100644 index 5289ebc..0000000 --- a/runner/lua/string.lua +++ /dev/null @@ -1,195 +0,0 @@ --- string.lua - -local string_ext = {} - --- ====================================================================== --- STRING UTILITY FUNCTIONS --- ====================================================================== - --- Trim whitespace from both ends -function string_ext.trim(s) - if type(s) ~= "string" then return s end - return s:match("^%s*(.-)%s*$") -end - --- Split string by delimiter -function string_ext.split(s, delimiter) - if type(s) ~= "string" then return {} end - - delimiter = delimiter or "," - local result = {} - for match in (s..delimiter):gmatch("(.-)"..delimiter) do - table.insert(result, match) - end - return result -end - --- Check if string starts with prefix -function string_ext.starts_with(s, prefix) - if type(s) ~= "string" or type(prefix) ~= "string" then return false end - return s:sub(1, #prefix) == prefix -end - --- Check if string ends with suffix -function string_ext.ends_with(s, suffix) - if type(s) ~= "string" or type(suffix) ~= "string" then return false end - return suffix == "" or s:sub(-#suffix) == suffix -end - --- Left pad a string -function string_ext.pad_left(s, len, char) - if type(s) ~= "string" or type(len) ~= "number" then return s end - - char = char or " " - if #s >= len then return s end - - return string.rep(char:sub(1,1), len - #s) .. s -end - --- Right pad a string -function string_ext.pad_right(s, len, char) - if type(s) ~= "string" or type(len) ~= "number" then return s end - - char = char or " " - if #s >= len then return s end - - return s .. string.rep(char:sub(1,1), len - #s) -end - --- Center a string -function string_ext.center(s, width, char) - if type(s) ~= "string" or width <= #s then return s end - - char = char or " " - local pad_len = width - #s - local left_pad = math.floor(pad_len / 2) - local right_pad = pad_len - left_pad - - return string.rep(char:sub(1,1), left_pad) .. s .. string.rep(char:sub(1,1), right_pad) -end - --- Count occurrences of substring -function string_ext.count(s, substr) - if type(s) ~= "string" or type(substr) ~= "string" or #substr == 0 then return 0 end - - local count, pos = 0, 1 - while true do - pos = s:find(substr, pos, true) - if not pos then break end - count = count + 1 - pos = pos + 1 - end - return count -end - --- Capitalize first letter -function string_ext.capitalize(s) - if type(s) ~= "string" or #s == 0 then return s end - return s:sub(1,1):upper() .. s:sub(2) -end - --- Capitalize all words -function string_ext.title(s) - if type(s) ~= "string" then return s end - - return s:gsub("(%w)([%w]*)", function(first, rest) - return first:upper() .. rest:lower() - end) -end - --- Insert string at position -function string_ext.insert(s, pos, insert_str) - if type(s) ~= "string" or type(insert_str) ~= "string" then return s end - - pos = math.max(1, math.min(pos, #s + 1)) - return s:sub(1, pos - 1) .. insert_str .. s:sub(pos) -end - --- Remove substring -function string_ext.remove(s, start, length) - if type(s) ~= "string" then return s end - - length = length or 1 - if start < 1 or start > #s then return s end - - return s:sub(1, start - 1) .. s:sub(start + length) -end - --- Replace substring once -function string_ext.replace(s, old, new, n) - if type(s) ~= "string" or type(old) ~= "string" or #old == 0 then return s end - - new = new or "" - n = n or 1 - - return s:gsub(old:gsub("[%-%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%1"), new, n) -end - --- Check if string contains substring -function string_ext.contains(s, substr) - if type(s) ~= "string" or type(substr) ~= "string" then return false end - return s:find(substr, 1, true) ~= nil -end - --- Escape pattern magic characters -function string_ext.escape_pattern(s) - if type(s) ~= "string" then return s end - return s:gsub("[%-%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%1") -end - --- Wrap text at specified width -function string_ext.wrap(s, width, indent_first, indent_rest) - if type(s) ~= "string" or type(width) ~= "number" then return s end - - width = math.max(1, width) - indent_first = indent_first or "" - indent_rest = indent_rest or indent_first - - local result = {} - local line_prefix = indent_first - local pos = 1 - - while pos <= #s do - local line_width = width - #line_prefix - local end_pos = math.min(pos + line_width - 1, #s) - - if end_pos < #s then - local last_space = s:sub(pos, end_pos):match(".*%s()") - if last_space then - end_pos = pos + last_space - 2 - end - end - - table.insert(result, line_prefix .. s:sub(pos, end_pos)) - pos = end_pos + 1 - - -- Skip leading spaces on next line - while s:sub(pos, pos) == " " do - pos = pos + 1 - end - - line_prefix = indent_rest - end - - return table.concat(result, "\n") -end - --- Limit string length with ellipsis -function string_ext.truncate(s, length, ellipsis) - if type(s) ~= "string" then return s end - - ellipsis = ellipsis or "..." - if #s <= length then return s end - - return s:sub(1, length - #ellipsis) .. ellipsis -end - --- ====================================================================== --- INSTALL EXTENSIONS INTO STRING LIBRARY --- ====================================================================== - -for name, func in pairs(string) do - string_ext[name] = func -end - -return string_ext diff --git a/runner/lua/table.lua b/runner/lua/table.lua deleted file mode 100644 index 6bb5276..0000000 --- a/runner/lua/table.lua +++ /dev/null @@ -1,1090 +0,0 @@ --- table.lua - -local table_ext = {} - --- ====================================================================== --- SET OPERATIONS --- ====================================================================== - --- Remove duplicate values (like array_unique) -function table_ext.unique(t) - if type(t) ~= "table" then return {} end - - local seen = {} - local result = {} - - for _, v in ipairs(t) do - if not seen[v] then - seen[v] = true - table.insert(result, v) - end - end - - return result -end - --- Return items in first table that are present in all other tables (like array_intersect) -function table_ext.intersect(t1, ...) - if type(t1) ~= "table" then return {} end - - local args = {...} - local result = {} - - -- Convert all tables to sets for O(1) lookups - local sets = {} - for i, t in ipairs(args) do - if type(t) ~= "table" then - return {} - end - - sets[i] = {} - for _, v in ipairs(t) do - sets[i][v] = true - end - end - - -- Check each element in t1 against all other tables - for _, v in ipairs(t1) do - local present_in_all = true - - for i = 1, #args do - if not sets[i][v] then - present_in_all = false - break - end - end - - if present_in_all then - table.insert(result, v) - end - end - - return result -end - --- Return items in first table that are not present in other tables (like array_diff) -function table_ext.diff(t1, ...) - if type(t1) ~= "table" then return {} end - - local args = {...} - local result = {} - - -- Build unified set of elements from other tables - local others = {} - for _, t in ipairs(args) do - if type(t) == "table" then - for _, v in ipairs(t) do - others[v] = true - end - end - end - - -- Add elements from t1 that aren't in other tables - for _, v in ipairs(t1) do - if not others[v] then - table.insert(result, v) - end - end - - return result -end - --- ====================================================================== --- SEARCH AND FILTERING --- ====================================================================== - --- Check if value exists in table (like in_array) -function table_ext.contains(t, value) - if type(t) ~= "table" then return false end - - for _, v in ipairs(t) do - if v == value then - return true - end - end - - return false -end - --- Find key for a value (like array_search) -function table_ext.find(t, value) - if type(t) ~= "table" then return nil end - - for k, v in pairs(t) do - if v == value then - return k - end - end - - return nil -end - --- Filter table elements (like array_filter) -function table_ext.filter(t, func) - if type(t) ~= "table" or type(func) ~= "function" then return {} end - - local result = {} - - for k, v in pairs(t) do - if func(v, k) then - if type(k) == "number" and k % 1 == 0 and k > 0 then - -- For array-like tables, maintain numerical indices - table.insert(result, v) - else - -- For associative tables, preserve the key - result[k] = v - end - end - end - - return result -end - --- ====================================================================== --- TRANSFORMATION FUNCTIONS --- ====================================================================== - --- Apply a function to all values (like array_map) -function table_ext.map(t, func) - if type(t) ~= "table" or type(func) ~= "function" then return {} end - - local result = {} - - for k, v in pairs(t) do - if type(k) == "number" and k % 1 == 0 and k > 0 then - -- For array-like tables, maintain numerical indices - table.insert(result, func(v, k)) - else - -- For associative tables, preserve the key - result[k] = func(v, k) - end - end - - return result -end - --- Reduce a table to a single value (like array_reduce) -function table_ext.reduce(t, func, initial) - if type(t) ~= "table" or type(func) ~= "function" then - return initial - end - - local result = initial - - for k, v in pairs(t) do - if result == nil then - result = v - else - result = func(result, v, k) - end - end - - return result -end - --- ====================================================================== --- ADVANCED OPERATIONS --- ====================================================================== - --- Split table into chunks (like array_chunk) -function table_ext.chunk(t, size) - if type(t) ~= "table" or type(size) ~= "number" or size <= 0 then - return {} - end - - local result = {} - local chunk = {} - local count = 0 - - for _, v in ipairs(t) do - count = count + 1 - chunk[count] = v - - if count == size then - table.insert(result, chunk) - chunk = {} - count = 0 - end - end - - -- Add the last chunk if it has any elements - if count > 0 then - table.insert(result, chunk) - end - - return result -end - --- Extract a column from a table of tables (like array_column) -function table_ext.column(t, column_key, index_key) - if type(t) ~= "table" or column_key == nil then return {} end - - local result = {} - - for _, row in ipairs(t) do - if type(row) == "table" and row[column_key] ~= nil then - if index_key ~= nil and row[index_key] ~= nil then - result[row[index_key]] = row[column_key] - else - table.insert(result, row[column_key]) - end - end - end - - return result -end - --- Merge tables (like array_merge, but preserves keys) -function table_ext.merge(...) - local result = {} - - for _, t in ipairs({...}) do - if type(t) == "table" then - for k, v in pairs(t) do - if type(k) == "number" and k % 1 == 0 and k > 0 then - -- For array-like tables, append values - table.insert(result, v) - else - -- For associative tables, overwrite with latest value - result[k] = v - end - end - end - end - - return result -end - --- ====================================================================== --- KEY MANIPULATION --- ====================================================================== - --- Exchange keys with values (like array_flip) -function table_ext.flip(t) - if type(t) ~= "table" then return {} end - - local result = {} - - for k, v in pairs(t) do - if type(v) == "string" or type(v) == "number" then - result[v] = k - end - end - - return result -end - --- Get all keys from a table (like array_keys) -function table_ext.keys(t) - if type(t) ~= "table" then return {} end - - local result = {} - - for k, _ in pairs(t) do - table.insert(result, k) - end - - return result -end - --- Get all values from a table (like array_values) -function table_ext.values(t) - if type(t) ~= "table" then return {} end - - local result = {} - - for _, v in pairs(t) do - table.insert(result, v) - end - - return result -end - --- ====================================================================== --- STATISTICAL FUNCTIONS --- ====================================================================== - --- Sum all values (like array_sum) -function table_ext.sum(t) - if type(t) ~= "table" then return 0 end - - local sum = 0 - - for _, v in pairs(t) do - if type(v) == "number" then - sum = sum + v - end - end - - return sum -end - --- Multiply all values (like array_product) -function table_ext.product(t) - if type(t) ~= "table" then return 0 end - - local product = 1 - local has_number = false - - for _, v in pairs(t) do - if type(v) == "number" then - product = product * v - has_number = true - end - end - - return has_number and product or 0 -end - --- Count value occurrences (like array_count_values) -function table_ext.count_values(t) - if type(t) ~= "table" then return {} end - - local result = {} - - for _, v in pairs(t) do - if type(v) == "string" or type(v) == "number" then - result[v] = (result[v] or 0) + 1 - end - end - - return result -end - --- ====================================================================== --- CREATION HELPERS --- ====================================================================== - --- Create a table with a range of values (like range) -function table_ext.range(start, stop, step) - if type(start) ~= "number" then return {} end - - step = step or 1 - - local result = {} - - if not stop then - stop = start - start = 1 - end - - if (step > 0 and start > stop) or (step < 0 and start < stop) then - return {} - end - - local i = start - while (step > 0 and i <= stop) or (step < 0 and i >= stop) do - table.insert(result, i) - i = i + step - end - - return result -end - --- Fill a table with a value (like array_fill) -function table_ext.fill(start_index, count, value) - if type(start_index) ~= "number" or type(count) ~= "number" or count < 0 then - return {} - end - - local result = {} - - for i = 0, count - 1 do - result[start_index + i] = value - end - - return result -end - --- ====================================================================== --- ADDITIONAL USEFUL FUNCTIONS --- ====================================================================== - --- Reverse a table (array part only) -function table_ext.reverse(t) - if type(t) ~= "table" then return {} end - - local result = {} - local count = #t - - for i = count, 1, -1 do - table.insert(result, t[i]) - end - - return result -end - --- Get the max value in a table -function table_ext.max(t) - if type(t) ~= "table" or #t == 0 then return nil end - - local max = t[1] - - for i = 2, #t do - if t[i] > max then - max = t[i] - end - end - - return max -end - --- Get the min value in a table -function table_ext.min(t) - if type(t) ~= "table" or #t == 0 then return nil end - - local min = t[1] - - for i = 2, #t do - if t[i] < min then - min = t[i] - end - end - - return min -end - --- Check if all elements satisfy a condition -function table_ext.all(t, func) - if type(t) ~= "table" or type(func) ~= "function" then return false end - - for k, v in pairs(t) do - if not func(v, k) then - return false - end - end - - return true -end - --- Check if any element satisfies a condition -function table_ext.any(t, func) - if type(t) ~= "table" or type(func) ~= "function" then return false end - - for k, v in pairs(t) do - if func(v, k) then - return true - end - end - - return false -end - --- ====================================================================== --- TABLE UTILITIES --- ====================================================================== - --- Check if table is empty -function table_ext.is_empty(t) - if type(t) ~= "table" then return true end - return next(t) == nil -end - --- Get table length (works for both array and hash parts) -function table_ext.size(t) - if type(t) ~= "table" then return 0 end - - local count = 0 - for _ in pairs(t) do - count = count + 1 - end - - return count -end - --- Get a slice of a table -function table_ext.slice(t, start, stop) - if type(t) ~= "table" then return {} end - - local len = #t - start = start or 1 - stop = stop or len - - -- Convert negative indices - if start < 0 then start = len + start + 1 end - if stop < 0 then stop = len + stop + 1 end - - -- Ensure bounds - start = math.max(1, math.min(start, len + 1)) - stop = math.max(0, math.min(stop, len)) - - local result = {} - for i = start, stop do - table.insert(result, t[i]) - end - - return result -end - --- ====================================================================== --- SORTING FUNCTIONS --- ====================================================================== - --- Sort array values (like sort) -function table_ext.sort(t) - if type(t) ~= "table" then return t end - table.sort(t) - return t -end - --- Sort array values in reverse order (like rsort) -function table_ext.rsort(t) - if type(t) ~= "table" then return t end - table.sort(t, function(a, b) return a > b end) - return t -end - --- Sort and maintain index association (like asort) -function table_ext.asort(t) - if type(t) ~= "table" then return t end - - local keys, result = {}, {} - for k in pairs(t) do - table.insert(keys, k) - end - - table.sort(keys, function(a, b) - return t[a] < t[b] - end) - - for _, k in ipairs(keys) do - result[k] = t[k] - end - - return result -end - --- Sort in reverse order and maintain index association (like arsort) -function table_ext.arsort(t) - if type(t) ~= "table" then return t end - - local keys, result = {}, {} - for k in pairs(t) do - table.insert(keys, k) - end - - table.sort(keys, function(a, b) - return t[a] > t[b] - end) - - for _, k in ipairs(keys) do - result[k] = t[k] - end - - return result -end - --- Sort by keys (like ksort) -function table_ext.ksort(t) - if type(t) ~= "table" then return t end - - local keys, result = {}, {} - for k in pairs(t) do - table.insert(keys, k) - end - - table.sort(keys) - - for _, k in ipairs(keys) do - result[k] = t[k] - end - - return result -end - --- Sort by keys in reverse order (like krsort) -function table_ext.krsort(t) - if type(t) ~= "table" then return t end - - local keys, result = {}, {} - for k in pairs(t) do - table.insert(keys, k) - end - - table.sort(keys, function(a, b) return a > b end) - - for _, k in ipairs(keys) do - result[k] = t[k] - end - - return result -end - --- Sort using custom comparison function (like usort) -function table_ext.usort(t, compare_func) - if type(t) ~= "table" or type(compare_func) ~= "function" then return t end - - table.sort(t, compare_func) - return t -end - --- Sort maintaining keys using custom comparison function (like uasort) -function table_ext.uasort(t, compare_func) - if type(t) ~= "table" or type(compare_func) ~= "function" then return t end - - local keys, result = {}, {} - for k in pairs(t) do - table.insert(keys, k) - end - - table.sort(keys, function(a, b) - return compare_func(t[a], t[b]) - end) - - for _, k in ipairs(keys) do - result[k] = t[k] - end - - return result -end - --- Sort by keys using custom comparison function (like uksort) -function table_ext.uksort(t, compare_func) - if type(t) ~= "table" or type(compare_func) ~= "function" then return t end - - local keys, result = {}, {} - for k in pairs(t) do - table.insert(keys, k) - end - - table.sort(keys, compare_func) - - for _, k in ipairs(keys) do - result[k] = t[k] - end - - return result -end - --- Natural order sort (like natsort) -function table_ext.natsort(t) - if type(t) ~= "table" then return t end - - local function natural_compare(a, b) - local function get_chunks(s) - if type(s) ~= "string" then s = tostring(s) end - local chunks = {} - for num, alpha in s:gmatch("(%d+)([^%d]*)") do - table.insert(chunks, {n=true, val=tonumber(num)}) - if alpha ~= "" then - table.insert(chunks, {n=false, val=alpha}) - end - end - return chunks - end - - local a_chunks = get_chunks(a) - local b_chunks = get_chunks(b) - - for i = 1, math.min(#a_chunks, #b_chunks) do - if a_chunks[i].n ~= b_chunks[i].n then - return a_chunks[i].n -- numbers come before strings - elseif a_chunks[i].val ~= b_chunks[i].val then - if a_chunks[i].n then - return a_chunks[i].val < b_chunks[i].val - else - return a_chunks[i].val < b_chunks[i].val - end - end - end - - return #a_chunks < #b_chunks - end - - table.sort(t, natural_compare) - return t -end - --- Natural case-insensitive sort (like natcasesort) -function table_ext.natcasesort(t) - if type(t) ~= "table" then return t end - - local function case_insensitive_natural_compare(a, b) - if type(a) == "string" and type(b) == "string" then - return table_ext.natural_compare(a:lower(), b:lower()) - else - return table_ext.natural_compare(a, b) - end - end - - return table_ext.usort(t, case_insensitive_natural_compare) -end - --- ====================================================================== --- ARRAY MODIFICATION FUNCTIONS --- ====================================================================== - --- Push one or more elements onto the end (like array_push) -function table_ext.push(t, ...) - if type(t) ~= "table" then return 0 end - - local count = 0 - for _, v in ipairs({...}) do - table.insert(t, v) - count = count + 1 - end - - return count -end - --- Pop the element off the end (like array_pop) -function table_ext.pop(t) - if type(t) ~= "table" or #t == 0 then return nil end - - local value = t[#t] - t[#t] = nil - return value -end - --- Shift an element off the beginning (like array_shift) -function table_ext.shift(t) - if type(t) ~= "table" or #t == 0 then return nil end - - local value = t[1] - table.remove(t, 1) - return value -end - --- Prepend elements to the beginning (like array_unshift) -function table_ext.unshift(t, ...) - if type(t) ~= "table" then return 0 end - - local args = {...} - for i = #args, 1, -1 do - table.insert(t, 1, args[i]) - end - - return #t -end - --- Pad array to specified length (like array_pad) -function table_ext.pad(t, size, value) - if type(t) ~= "table" then return {} end - - local result = table_ext.deep_copy(t) - local current_size = #result - - if size == current_size then return result end - - if size > current_size then - -- Pad to the right - for i = current_size + 1, size do - result[i] = value - end - else - -- Pad to the left (negative size) - local abs_size = math.abs(size) - if abs_size < current_size then - local temp = {} - for i = 1, abs_size do - if i <= abs_size - current_size then - temp[i] = value - else - temp[i] = result[i - (abs_size - current_size)] - end - end - result = temp - else - -- Fill completely with padding value - result = {} - for i = 1, abs_size do - result[i] = value - end - end - end - - return result -end - --- Remove a portion and replace it (like array_splice) -function table_ext.splice(t, offset, length, ...) - if type(t) ~= "table" then return {} end - - local result = table_ext.deep_copy(t) - local size = #result - - -- Handle negative offset - if offset < 0 then - offset = size + offset - end - - -- Ensure offset is valid - offset = math.max(1, math.min(offset, size + 1)) - - -- Handle negative or nil length - if length == nil then - length = size - offset + 1 - elseif length < 0 then - length = math.max(0, size - offset + length + 1) - end - - -- Extract removed portion - local removed = {} - for i = offset, offset + length - 1 do - if i <= size then - table.insert(removed, result[i]) - end - end - - -- Remove portion from original - for i = 1, length do - table.remove(result, offset) - end - - -- Insert replacement values - local replacements = {...} - for i = #replacements, 1, -1 do - table.insert(result, offset, replacements[i]) - end - - return removed, result -end - --- Randomize array order (like shuffle) -function table_ext.shuffle(t) - if type(t) ~= "table" then return t end - - local result = table_ext.deep_copy(t) - local size = #result - - for i = size, 2, -1 do - local j = math.random(i) - result[i], result[j] = result[j], result[i] - end - - return result -end - --- Pick random keys from array (like array_rand) -function table_ext.rand(t, num_keys) - if type(t) ~= "table" then return nil end - - local size = #t - if size == 0 then return nil end - - num_keys = num_keys or 1 - num_keys = math.min(num_keys, size) - - if num_keys <= 0 then return nil end - - if num_keys == 1 then - return math.random(size) - else - local keys = {} - local result = {} - - -- Create a list of all possible keys - for i = 1, size do - keys[i] = i - end - - -- Select random keys - for i = 1, num_keys do - local j = math.random(#keys) - table.insert(result, keys[j]) - table.remove(keys, j) - end - - table.sort(result) - return result - end -end - --- ====================================================================== --- ARRAY INSPECTION FUNCTIONS --- ====================================================================== - --- Check if key exists (like array_key_exists) -function table_ext.key_exists(key, t) - if type(t) ~= "table" then return false end - return t[key] ~= nil -end - --- Get the first key (like array_key_first) -function table_ext.key_first(t) - if type(t) ~= "table" then return nil end - - -- For array-like tables - if #t > 0 then return 1 end - - -- For associative tables - local first_key = nil - for k in pairs(t) do - first_key = k - break - end - - return first_key -end - --- Get the last key (like array_key_last) -function table_ext.key_last(t) - if type(t) ~= "table" then return nil end - - -- For array-like tables - if #t > 0 then return #t end - - -- For associative tables (no guaranteed order, return any key) - local last_key = nil - for k in pairs(t) do - last_key = k - end - - return last_key -end - --- Check if table is a list (like array_is_list) -function table_ext.is_list(t) - if type(t) ~= "table" then return false end - - local count = 0 - for k in pairs(t) do - count = count + 1 - if type(k) ~= "number" or k <= 0 or math.floor(k) ~= k or k > count then - return false - end - end - - return true -end - --- ====================================================================== --- OTHER IMPORTANT FUNCTIONS --- ====================================================================== - --- Create array with keys from one array, values from another (like array_combine) -function table_ext.combine(keys, values) - if type(keys) ~= "table" or type(values) ~= "table" then return {} end - - local result = {} - local key_count = #keys - local value_count = #values - - for i = 1, math.min(key_count, value_count) do - result[keys[i]] = values[i] - end - - return result -end - --- Replace elements from one array into another (like array_replace) -function table_ext.replace(t, ...) - if type(t) ~= "table" then return {} end - - local result = table_ext.deep_copy(t) - - for _, replacement in ipairs({...}) do - if type(replacement) == "table" then - for k, v in pairs(replacement) do - result[k] = v - end - end - end - - return result -end - --- Replace elements recursively (like array_replace_recursive) -function table_ext.replace_recursive(t, ...) - if type(t) ~= "table" then return {} end - - local result = table_ext.deep_copy(t) - - for _, replacement in ipairs({...}) do - if type(replacement) == "table" then - for k, v in pairs(replacement) do - if type(v) == "table" and type(result[k]) == "table" then - result[k] = table_ext.replace_recursive(result[k], v) - else - result[k] = v - end - end - end - end - - return result -end - --- Apply function to each element (like array_walk) -function table_ext.walk(t, callback, user_data) - if type(t) ~= "table" or type(callback) ~= "function" then return t end - - for k, v in pairs(t) do - t[k] = callback(v, k, user_data) - end - - return t -end - --- Apply function recursively (like array_walk_recursive) -function table_ext.walk_recursive(t, callback, user_data) - if type(t) ~= "table" or type(callback) ~= "function" then return t end - - for k, v in pairs(t) do - if type(v) == "table" then - table_ext.walk_recursive(v, callback, user_data) - else - t[k] = callback(v, k, user_data) - end - end - - return t -end - --- Sort multiple arrays (simplified array_multisort) -function table_ext.multisort(...) - local args = {...} - if #args == 0 then return end - - -- First argument is the main table - local main = args[1] - if type(main) ~= "table" then return end - - -- Create a table of indices - local indices = {} - for i = 1, #main do - indices[i] = i - end - - -- Sort the indices based on the arrays - table.sort(indices, function(a, b) - for i = 1, #args do - local arr = args[i] - if type(arr) == "table" then - if arr[a] ~= arr[b] then - return arr[a] < arr[b] - end - end - end - return a < b - end) - - -- Reorder all arrays based on sorted indices - for i = 1, #args do - local arr = args[i] - if type(arr) == "table" then - local temp = table_ext.deep_copy(arr) - for j = 1, #indices do - arr[j] = temp[indices[j]] - end - end - end -end - --- Efficient deep copy function -function table_ext.deep_copy(obj) - if type(obj) ~= 'table' then return obj end - local res = {} - for k, v in pairs(obj) do res[k] = table_ext.deep_copy(v) end - return res -end - --- ====================================================================== --- INSTALL EXTENSIONS INTO TABLE LIBRARY --- ====================================================================== - -for name, func in pairs(table) do - table_ext[name] = func -end - -return table_ext diff --git a/runner/lua/time.lua b/runner/lua/time.lua deleted file mode 100644 index 2e37a86..0000000 --- a/runner/lua/time.lua +++ /dev/null @@ -1,128 +0,0 @@ --- time.lua - -local ffi = require('ffi') -local is_windows = (ffi.os == "Windows") - --- Define C structures and functions based on platform -if is_windows then - ffi.cdef[[ - typedef struct { - int64_t QuadPart; - } LARGE_INTEGER; - int QueryPerformanceCounter(LARGE_INTEGER* lpPerformanceCount); - int QueryPerformanceFrequency(LARGE_INTEGER* lpFrequency); - ]] -else - ffi.cdef[[ - typedef long time_t; - typedef struct timeval { - long tv_sec; - long tv_usec; - } timeval; - int gettimeofday(struct timeval* tv, void* tz); - time_t time(time_t* t); - ]] -end - -local time = {} -local has_initialized = false -local start_time, timer_freq - --- Initialize timing system based on platform -local function init() - if has_initialized then return end - - if ffi.os == "Windows" then - local frequency = ffi.new("LARGE_INTEGER") - ffi.C.QueryPerformanceFrequency(frequency) - timer_freq = tonumber(frequency.QuadPart) - - local counter = ffi.new("LARGE_INTEGER") - ffi.C.QueryPerformanceCounter(counter) - start_time = tonumber(counter.QuadPart) - else - -- Nothing special needed for Unix platform init - start_time = ffi.C.time(nil) - end - - has_initialized = true -end - --- PHP-compatible microtime implementation -function time.microtime(get_as_float) - init() - - if ffi.os == "Windows" then - local counter = ffi.new("LARGE_INTEGER") - ffi.C.QueryPerformanceCounter(counter) - local now = tonumber(counter.QuadPart) - local seconds = math.floor((now - start_time) / timer_freq) - local microseconds = ((now - start_time) % timer_freq) * 1000000 / timer_freq - - if get_as_float then - return seconds + microseconds / 1000000 - else - return string.format("0.%06d %d", microseconds, seconds) - end - else - local tv = ffi.new("struct timeval") - ffi.C.gettimeofday(tv, nil) - - if get_as_float then - return tonumber(tv.tv_sec) + tonumber(tv.tv_usec) / 1000000 - else - return string.format("0.%06d %d", tv.tv_usec, tv.tv_sec) - end - end -end - --- High-precision monotonic timer (returns seconds with microsecond precision) -function time.monotonic() - init() - - if ffi.os == "Windows" then - local counter = ffi.new("LARGE_INTEGER") - ffi.C.QueryPerformanceCounter(counter) - local now = tonumber(counter.QuadPart) - return (now - start_time) / timer_freq - else - local tv = ffi.new("struct timeval") - ffi.C.gettimeofday(tv, nil) - return tonumber(tv.tv_sec) - start_time + tonumber(tv.tv_usec) / 1000000 - end -end - --- Benchmark function that measures execution time -function time.benchmark(func, iterations, warmup) - iterations = iterations or 1000 - warmup = warmup or 10 - - -- Warmup - for i=1, warmup do func() end - - local start = time.microtime(true) - for i=1, iterations do - func() - end - local finish = time.microtime(true) - - local elapsed = (finish - start) * 1000000 -- Convert to microseconds - return elapsed / iterations -end - --- Simple sleep function using coroutine yielding -function time.sleep(seconds) - if type(seconds) ~= "number" or seconds <= 0 then - return - end - - local start = time.monotonic() - while time.monotonic() - start < seconds do - -- Use coroutine.yield to avoid consuming CPU - coroutine.yield() - end -end - -_G.microtime = time.microtime - -return time diff --git a/runner/lua/timestamp.lua b/runner/lua/timestamp.lua deleted file mode 100644 index f62b9ec..0000000 --- a/runner/lua/timestamp.lua +++ /dev/null @@ -1,93 +0,0 @@ --- timestamp.lua -local timestamp = {} - --- Standard format presets using Lua format codes -local FORMATS = { - iso = "%Y-%m-%dT%H:%M:%SZ", - datetime = "%Y-%m-%d %H:%M:%S", - us_date = "%m/%d/%Y", - us_datetime = "%m/%d/%Y %I:%M:%S %p", - date = "%Y-%m-%d", - time = "%H:%M:%S", - time12 = "%I:%M:%S %p", - readable = "%B %d, %Y %I:%M:%S %p", - compact = "%Y%m%d_%H%M%S" -} - --- Parse input to unix timestamp and microseconds -local function parse_input(input) - local unix_time, micros = 0, 0 - - if type(input) == "string" then - local frac, secs = input:match("^(0%.%d+)%s+(%d+)$") - if frac and secs then - unix_time = tonumber(secs) - micros = math.floor((tonumber(frac) * 1000000) + 0.5) - else - unix_time = tonumber(input) or 0 - end - elseif type(input) == "number" then - unix_time = math.floor(input) - micros = math.floor(((input - unix_time) * 1000000) + 0.5) - end - - return unix_time, micros -end - --- Remove leading zeros from number string -local function no_leading_zero(s) - return s:gsub("^0+", "") or "0" -end - --- Main format function -function timestamp.format(input, fmt) - fmt = fmt or "datetime" - local format_str = FORMATS[fmt] or fmt - local unix_time, micros = parse_input(input) - local result = os.date(format_str, unix_time) - - -- Handle microseconds if format contains dot - if format_str:find("%.") then - result = result .. string.format(".%06d", micros) - end - - return result -end - --- US date/time with no leading zeros -function timestamp.us_datetime_no_zero(input) - local unix_time, micros = parse_input(input) - local month = no_leading_zero(os.date("%m", unix_time)) - local day = no_leading_zero(os.date("%d", unix_time)) - local year = os.date("%Y", unix_time) - local hour = no_leading_zero(os.date("%I", unix_time)) - local min = os.date("%M", unix_time) - local sec = os.date("%S", unix_time) - local ampm = os.date("%p", unix_time) - - return string.format("%s/%s/%s %s:%s:%s %s", month, day, year, hour, min, sec, ampm) -end - --- Quick preset functions -function timestamp.iso(input) return timestamp.format(input, "iso") end -function timestamp.datetime(input) return timestamp.format(input, "datetime") end -function timestamp.us_date(input) return timestamp.format(input, "us_date") end -function timestamp.us_datetime(input) return timestamp.us_datetime_no_zero(input) end -function timestamp.date(input) return timestamp.format(input, "date") end -function timestamp.time(input) return timestamp.format(input, "time") end -function timestamp.time12(input) return timestamp.format(input, "time12") end -function timestamp.readable(input) return timestamp.format(input, "readable") end - --- Microsecond precision variants -function timestamp.datetime_micro(input) - return timestamp.format(input, "%Y-%m-%d %H:%M:%S.") .. string.format("%06d", select(2, parse_input(input))) -end - -function timestamp.iso_micro(input) - return timestamp.format(input, "%Y-%m-%dT%H:%M:%S.") .. string.format("%06dZ", select(2, parse_input(input))) -end - --- Register global convenience function -_G.format_time = timestamp.format - -return timestamp diff --git a/runner/lua/util.lua b/runner/lua/util.lua deleted file mode 100644 index 3874ca4..0000000 --- a/runner/lua/util.lua +++ /dev/null @@ -1,323 +0,0 @@ --- util.lua - --- ====================================================================== --- CORE UTILITY FUNCTIONS --- ====================================================================== - --- Generate a random token -function generate_token(length) - return __generate_token(length or 32) -end - --- ====================================================================== --- HTML ENTITY FUNCTIONS --- ====================================================================== - --- Convert special characters to HTML entities (like htmlspecialchars) -function html_special_chars(str) - if type(str) ~= "string" then - return str - end - - return __html_special_chars(str) -end - --- Convert all applicable characters to HTML entities (like htmlentities) -function html_entities(str) - if type(str) ~= "string" then - return str - end - - return __html_entities(str) -end - --- Convert HTML entities back to characters (simple version) -function html_entity_decode(str) - if type(str) ~= "string" then - return str - end - - str = str:gsub("<", "<") - str = str:gsub(">", ">") - str = str:gsub(""", '"') - str = str:gsub("'", "'") - str = str:gsub("&", "&") - - return str -end - --- Convert newlines to
tags -function nl2br(str) - if type(str) ~= "string" then - return str - end - - return str:gsub("\r\n", "
"):gsub("\n", "
"):gsub("\r", "
") -end - --- ====================================================================== --- URL FUNCTIONS --- ====================================================================== - --- URL encode a string -function url_encode(str) - if type(str) ~= "string" then - return str - end - - str = str:gsub("\n", "\r\n") - str = str:gsub("([^%w %-%_%.%~])", function(c) - return string.format("%%%02X", string.byte(c)) - end) - str = str:gsub(" ", "+") - return str -end - --- URL decode a string -function url_decode(str) - if type(str) ~= "string" then - return str - end - - str = str:gsub("+", " ") - str = str:gsub("%%(%x%x)", function(h) - return string.char(tonumber(h, 16)) - end) - return str -end - --- ====================================================================== --- VALIDATION FUNCTIONS --- ====================================================================== - --- Email validation -function is_email(str) - if type(str) ~= "string" then - return false - end - - -- Simple email validation pattern - local pattern = "^[%w%.%%%+%-]+@[%w%.%%%+%-]+%.%w%w%w?%w?$" - return str:match(pattern) ~= nil -end - --- URL validation -function is_url(str) - if type(str) ~= "string" then - return false - end - - -- Simple URL validation - local pattern = "^https?://[%w-_%.%?%.:/%+=&%%]+$" - return str:match(pattern) ~= nil -end - --- IP address validation (IPv4) -function is_ipv4(str) - if type(str) ~= "string" then - return false - end - - local pattern = "^(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)$" - local a, b, c, d = str:match(pattern) - - if not (a and b and c and d) then - return false - end - - a, b, c, d = tonumber(a), tonumber(b), tonumber(c), tonumber(d) - return a <= 255 and b <= 255 and c <= 255 and d <= 255 -end - --- Integer validation -function is_int(str) - if type(str) == "number" then - return math.floor(str) == str - elseif type(str) ~= "string" then - return false - end - - return str:match("^-?%d+$") ~= nil -end - --- Float validation -function is_float(str) - if type(str) == "number" then - return true - elseif type(str) ~= "string" then - return false - end - - return str:match("^-?%d+%.?%d*$") ~= nil -end - --- Boolean validation -function is_bool(value) - if type(value) == "boolean" then - return true - elseif type(value) ~= "string" and type(value) ~= "number" then - return false - end - - local v = type(value) == "string" and value:lower() or value - return v == "1" or v == "true" or v == "on" or v == "yes" or - v == "0" or v == "false" or v == "off" or v == "no" or - v == 1 or v == 0 -end - --- Convert to boolean -function to_bool(value) - if type(value) == "boolean" then - return value - elseif type(value) ~= "string" and type(value) ~= "number" then - return false - end - - local v = type(value) == "string" and value:lower() or value - return v == "1" or v == "true" or v == "on" or v == "yes" or v == 1 -end - --- Sanitize string (simple version) -function sanitize_string(str) - if type(str) ~= "string" then - return "" - end - - return html_special_chars(str) -end - --- Sanitize to integer -function sanitize_int(value) - if type(value) ~= "string" and type(value) ~= "number" then - return 0 - end - - value = tostring(value) - local result = value:match("^-?%d+") - return result and tonumber(result) or 0 -end - --- Sanitize to float -function sanitize_float(value) - if type(value) ~= "string" and type(value) ~= "number" then - return 0 - end - - value = tostring(value) - local result = value:match("^-?%d+%.?%d*") - return result and tonumber(result) or 0 -end - --- Sanitize URL -function sanitize_url(str) - if type(str) ~= "string" then - return "" - end - - -- Basic sanitization by removing control characters - str = str:gsub("[\000-\031]", "") - - -- Make sure it's a valid URL - if is_url(str) then - return str - end - - -- Try to prepend http:// if it's missing - if not str:match("^https?://") and is_url("http://" .. str) then - return "http://" .. str - end - - return "" -end - --- Sanitize email -function sanitize_email(str) - if type(str) ~= "string" then - return "" - end - - -- Remove all characters except common email characters - str = str:gsub("[^%a%d%!%#%$%%%&%'%*%+%-%/%=%?%^%_%`%{%|%}%~%@%.%[%]]", "") - - -- Return only if it's a valid email - if is_email(str) then - return str - end - - return "" -end - --- ====================================================================== --- SECURITY FUNCTIONS --- ====================================================================== - --- Basic XSS prevention -function xss_clean(str) - if type(str) ~= "string" then - return str - end - - -- Convert problematic characters to entities - local result = html_special_chars(str) - - -- Remove JavaScript event handlers - result = result:gsub("on%w+%s*=", "") - - -- Remove JavaScript protocol - result = result:gsub("javascript:", "") - - -- Remove CSS expression - result = result:gsub("expression%s*%(", "") - - return result -end - --- Base64 encode -function base64_encode(str) - if type(str) ~= "string" then - return str - end - - return __base64_encode(str) -end - --- Base64 decode -function base64_decode(str) - if type(str) ~= "string" then - return str - end - - return __base64_decode(str) -end - --- ====================================================================== --- PASSWORD FUNCTIONS --- ====================================================================== - --- Hash a password using Argon2id --- Options: --- memory: Amount of memory to use in KB (default: 128MB) --- iterations: Number of iterations (default: 4) --- parallelism: Number of threads (default: 4) --- salt_length: Length of salt in bytes (default: 16) --- key_length: Length of the derived key in bytes (default: 32) -function password_hash(plain_password, options) - if type(plain_password) ~= "string" then - error("password_hash: expected string password", 2) - end - - return __password_hash(plain_password, options) -end - --- Verify a password against a hash -function password_verify(plain_password, hash_string) - if type(plain_password) ~= "string" then - error("password_verify: expected string password", 2) - end - - if type(hash_string) ~= "string" then - error("password_verify: expected string hash", 2) - end - - return __password_verify(plain_password, hash_string) -end diff --git a/runner/lualibs/crypto.go b/runner/lualibs/crypto.go deleted file mode 100644 index aa858d6..0000000 --- a/runner/lualibs/crypto.go +++ /dev/null @@ -1,463 +0,0 @@ -package lualibs - -import ( - "Moonshark/logger" - "crypto/hmac" - "crypto/md5" - "crypto/rand" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "encoding/base64" - "encoding/binary" - "encoding/hex" - "fmt" - "hash" - "math" - mrand "math/rand/v2" - "sync" - "time" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -var ( - // Map to store state-specific RNGs - stateRngs = make(map[*luajit.State]*mrand.PCG) - stateRngsMu sync.Mutex -) - -// RegisterCryptoFunctions registers all crypto functions with the Lua state -func RegisterCryptoFunctions(state *luajit.State) error { - // Create a state-specific RNG - stateRngsMu.Lock() - stateRngs[state] = mrand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano()>>32)) - stateRngsMu.Unlock() - - if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil { - return err - } - - // Register hash functions - if err := state.RegisterGoFunction("__crypto_hash", cryptoHash); err != nil { - return err - } - - // Register HMAC functions - if err := state.RegisterGoFunction("__crypto_hmac", cryptoHmac); err != nil { - return err - } - - // Register UUID generation - if err := state.RegisterGoFunction("__crypto_uuid", cryptoUuid); err != nil { - return err - } - - // Register random functions - if err := state.RegisterGoFunction("__crypto_random", cryptoRandom); err != nil { - return err - } - if err := state.RegisterGoFunction("__crypto_random_bytes", cryptoRandomBytes); err != nil { - return err - } - if err := state.RegisterGoFunction("__crypto_random_int", cryptoRandomInt); err != nil { - return err - } - if err := state.RegisterGoFunction("__crypto_random_seed", cryptoRandomSeed); err != nil { - return err - } - - // Override Lua's math.random - if err := OverrideLuaRandom(state); err != nil { - return err - } - - return nil -} - -// CleanupCrypto cleans up resources when a state is closed -func CleanupCrypto(state *luajit.State) { - stateRngsMu.Lock() - delete(stateRngs, state) - stateRngsMu.Unlock() -} - -// cryptoHash generates hash digests using various algorithms -func cryptoHash(state *luajit.State) int { - if err := state.CheckMinArgs(2); err != nil { - return state.PushError("hash: %v", err) - } - - data, err := state.SafeToString(1) - if err != nil { - return state.PushError("hash: data must be string") - } - - algorithm, err := state.SafeToString(2) - if err != nil { - return state.PushError("hash: algorithm must be string") - } - - var h hash.Hash - switch algorithm { - case "md5": - h = md5.New() - case "sha1": - h = sha1.New() - case "sha256": - h = sha256.New() - case "sha512": - h = sha512.New() - default: - return state.PushError("unsupported algorithm: %s", algorithm) - } - - h.Write([]byte(data)) - hashBytes := h.Sum(nil) - - // Output format - outputFormat := "hex" - if state.GetTop() >= 3 { - if format, err := state.SafeToString(3); err == nil { - outputFormat = format - } - } - - switch outputFormat { - case "hex": - state.PushString(hex.EncodeToString(hashBytes)) - case "binary": - state.PushString(string(hashBytes)) - default: - state.PushString(hex.EncodeToString(hashBytes)) - } - - return 1 -} - -// cryptoHmac generates HMAC using various hash algorithms -func cryptoHmac(state *luajit.State) int { - if err := state.CheckMinArgs(3); err != nil { - return state.PushError("hmac: %v", err) - } - - data, err := state.SafeToString(1) - if err != nil { - return state.PushError("hmac: data must be string") - } - - key, err := state.SafeToString(2) - if err != nil { - return state.PushError("hmac: key must be string") - } - - algorithm, err := state.SafeToString(3) - if err != nil { - return state.PushError("hmac: algorithm must be string") - } - - var h func() hash.Hash - switch algorithm { - case "md5": - h = md5.New - case "sha1": - h = sha1.New - case "sha256": - h = sha256.New - case "sha512": - h = sha512.New - default: - return state.PushError("unsupported algorithm: %s", algorithm) - } - - mac := hmac.New(h, []byte(key)) - mac.Write([]byte(data)) - macBytes := mac.Sum(nil) - - // Output format - outputFormat := "hex" - if state.GetTop() >= 4 { - if format, err := state.SafeToString(4); err == nil { - outputFormat = format - } - } - - switch outputFormat { - case "hex": - state.PushString(hex.EncodeToString(macBytes)) - case "binary": - state.PushString(string(macBytes)) - default: - state.PushString(hex.EncodeToString(macBytes)) - } - - return 1 -} - -// cryptoUuid generates a random UUID v4 -func cryptoUuid(state *luajit.State) int { - uuid := make([]byte, 16) - if _, err := rand.Read(uuid); err != nil { - return state.PushError("uuid: generation error: %v", err) - } - - // Set version (4) and variant (RFC 4122) - uuid[6] = (uuid[6] & 0x0F) | 0x40 - uuid[8] = (uuid[8] & 0x3F) | 0x80 - - uuidStr := fmt.Sprintf("%x-%x-%x-%x-%x", - uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:]) - - state.PushString(uuidStr) - return 1 -} - -// cryptoRandomBytes generates random bytes -func cryptoRandomBytes(state *luajit.State) int { - if err := state.CheckMinArgs(1); err != nil { - return state.PushError("random_bytes: %v", err) - } - - length, err := state.SafeToNumber(1) - if err != nil { - return state.PushError("random_bytes: length must be number") - } - - if length <= 0 { - return state.PushError("random_bytes: length must be positive") - } - - // Check if secure - secure := true - if state.GetTop() >= 2 && state.IsBoolean(2) { - secure = state.ToBoolean(2) - } - - bytes := make([]byte, int(length)) - - if secure { - if _, err := rand.Read(bytes); err != nil { - return state.PushError("random_bytes: error: %v", err) - } - } else { - stateRngsMu.Lock() - stateRng, ok := stateRngs[state] - stateRngsMu.Unlock() - - if !ok { - return state.PushError("random_bytes: RNG not initialized") - } - - for i := range bytes { - bytes[i] = byte(stateRng.Uint64() & 0xFF) - } - } - - // Output format - outputFormat := "binary" - if state.GetTop() >= 3 { - if format, err := state.SafeToString(3); err == nil { - outputFormat = format - } - } - - switch outputFormat { - case "binary": - state.PushString(string(bytes)) - case "hex": - state.PushString(hex.EncodeToString(bytes)) - default: - state.PushString(string(bytes)) - } - - return 1 -} - -// cryptoRandomInt generates a random integer in range [min, max] -func cryptoRandomInt(state *luajit.State) int { - if err := state.CheckMinArgs(2); err != nil { - return state.PushError("random_int: %v", err) - } - - minVal, err := state.SafeToNumber(1) - if err != nil { - return state.PushError("random_int: min must be number") - } - - maxVal, err := state.SafeToNumber(2) - if err != nil { - return state.PushError("random_int: max must be number") - } - - min := int64(minVal) - max := int64(maxVal) - - if max <= min { - return state.PushError("random_int: max must be greater than min") - } - - // Check if secure - secure := true - if state.GetTop() >= 3 && state.IsBoolean(3) { - secure = state.ToBoolean(3) - } - - range_size := max - min + 1 - var result int64 - - if secure { - bytes := make([]byte, 8) - if _, err := rand.Read(bytes); err != nil { - return state.PushError("random_int: error: %v", err) - } - - val := binary.BigEndian.Uint64(bytes) - result = min + int64(val%uint64(range_size)) - } else { - stateRngsMu.Lock() - stateRng, ok := stateRngs[state] - stateRngsMu.Unlock() - - if !ok { - return state.PushError("random_int: RNG not initialized") - } - - result = min + int64(stateRng.Uint64()%uint64(range_size)) - } - - state.PushNumber(float64(result)) - return 1 -} - -// cryptoRandom implements math.random functionality -func cryptoRandom(state *luajit.State) int { - numArgs := state.GetTop() - - // Check if secure - secure := false - - // math.random() - return [0,1) - if numArgs == 0 { - if secure { - bytes := make([]byte, 8) - if _, err := rand.Read(bytes); err != nil { - return state.PushError("random: error: %v", err) - } - val := binary.BigEndian.Uint64(bytes) - state.PushNumber(float64(val) / float64(math.MaxUint64)) - } else { - stateRngsMu.Lock() - stateRng, ok := stateRngs[state] - stateRngsMu.Unlock() - - if !ok { - return state.PushError("random: RNG not initialized") - } - - state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64)) - } - return 1 - } - - // math.random(n) - return integer [1,n] - if numArgs == 1 && state.IsNumber(1) { - n := int64(state.ToNumber(1)) - if n < 1 { - return state.PushError("random: upper bound must be >= 1") - } - - state.PushNumber(1) // min - state.PushNumber(float64(n)) // max - state.PushBoolean(secure) // secure flag - return cryptoRandomInt(state) - } - - // math.random(m, n) - return integer [m,n] - if numArgs >= 2 && state.IsNumber(1) && state.IsNumber(2) { - state.PushBoolean(secure) // secure flag - return cryptoRandomInt(state) - } - - return state.PushError("random: invalid arguments") -} - -// cryptoRandomSeed sets seed for non-secure RNG -func cryptoRandomSeed(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("randomseed: %v", err) - } - - seed, err := state.SafeToNumber(1) - if err != nil { - return state.PushError("randomseed: seed must be number") - } - - seedVal := uint64(seed) - stateRngsMu.Lock() - stateRngs[state] = mrand.NewPCG(seedVal, seedVal>>32) - stateRngsMu.Unlock() - - return 0 -} - -// OverrideLuaRandom replaces Lua's math.random with Go implementation -func OverrideLuaRandom(state *luajit.State) error { - if err := state.RegisterGoFunction("go_math_random", cryptoRandom); err != nil { - return err - } - - if err := state.RegisterGoFunction("go_math_randomseed", cryptoRandomSeed); err != nil { - return err - } - - // Replace original functions - return state.DoString(` - -- Save original functions - _G._original_math_random = math.random - _G._original_math_randomseed = math.randomseed - - -- Replace with Go implementations - math.random = go_math_random - math.randomseed = go_math_randomseed - - -- Clean up global namespace - go_math_random = nil - go_math_randomseed = nil - `) -} - -// generateToken creates a cryptographically secure random token -func generateToken(state *luajit.State) int { - // Get the length from the Lua arguments (default to 32) - length := 32 - if state.GetTop() >= 1 { - if lengthVal, err := state.SafeToNumber(1); err == nil { - length = int(lengthVal) - } - } - - // Enforce minimum length for security - if length < 16 { - length = 16 - } - - // Generate secure random bytes - tokenBytes := make([]byte, length) - if _, err := rand.Read(tokenBytes); err != nil { - logger.Errorf("Failed to generate secure token: %v", err) - state.PushString("") - return 1 // Return empty string on error - } - - // Encode as base64 - token := base64.RawURLEncoding.EncodeToString(tokenBytes) - - // Trim to requested length (base64 might be longer) - if len(token) > length { - token = token[:length] - } - - // Push the token to the Lua stack - state.PushString(token) - return 1 // One return value -} diff --git a/runner/lualibs/env.go b/runner/lualibs/env.go deleted file mode 100644 index d68a5d8..0000000 --- a/runner/lualibs/env.go +++ /dev/null @@ -1,333 +0,0 @@ -package lualibs - -import ( - "bufio" - "fmt" - "os" - "path/filepath" - "sort" - "strconv" - "strings" - "sync" - - "Moonshark/logger" - - "git.sharkk.net/Go/Color" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// EnvManager handles loading, storing, and saving environment variables -type EnvManager struct { - envPath string // Path to .env file - vars map[string]any // Environment variables in memory - mu sync.RWMutex // Thread-safe access -} - -// Global environment manager instance -var globalEnvManager *EnvManager - -// InitEnv initializes the environment manager with the given data directory -func InitEnv(dataDir string) error { - if dataDir == "" { - return fmt.Errorf("data directory cannot be empty") - } - - // Create data directory if it doesn't exist - if err := os.MkdirAll(dataDir, 0755); err != nil { - return fmt.Errorf("failed to create data directory: %w", err) - } - - envPath := filepath.Join(dataDir, ".env") - - globalEnvManager = &EnvManager{ - envPath: envPath, - vars: make(map[string]any), - } - - // Load existing .env file if it exists - if err := globalEnvManager.load(); err != nil { - logger.Warnf("Failed to load .env file: %v", err) - } - - count := len(globalEnvManager.vars) - if count > 0 { - logger.Infof("Environment loaded: %s vars from %s", - color.Yellow(fmt.Sprintf("%d", count)), - color.Yellow(envPath)) - } else { - logger.Infof("Environment initialized: %s", color.Yellow(envPath)) - } - - return nil -} - -// GetGlobalEnvManager returns the global environment manager instance -func GetGlobalEnvManager() *EnvManager { - return globalEnvManager -} - -// parseValue attempts to parse a string value into the appropriate type -func parseValue(value string) any { - // Try boolean first - if value == "true" { - return true - } - if value == "false" { - return false - } - - // Try number - if num, err := strconv.ParseFloat(value, 64); err == nil { - // Check if it's actually an integer - if num == float64(int64(num)) { - return int64(num) - } - return num - } - - // Default to string - return value -} - -// load reads the .env file and populates the vars map -func (e *EnvManager) load() error { - file, err := os.Open(e.envPath) - if os.IsNotExist(err) { - // File doesn't exist, start with empty env - return nil - } - if err != nil { - return fmt.Errorf("failed to open .env file: %w", err) - } - defer file.Close() - - e.mu.Lock() - defer e.mu.Unlock() - - scanner := bufio.NewScanner(file) - lineNum := 0 - - for scanner.Scan() { - lineNum++ - line := strings.TrimSpace(scanner.Text()) - - // Skip empty lines and comments - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - // Parse key=value - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - logger.Warnf("Invalid .env line %d: %s", lineNum, line) - continue - } - - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - - // Remove quotes if present - if len(value) >= 2 { - if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) || - (strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) { - value = value[1 : len(value)-1] - } - } - - e.vars[key] = parseValue(value) - } - - return scanner.Err() -} - -// Save writes the current environment variables to the .env file -func (e *EnvManager) Save() error { - if e == nil { - return nil // No env manager initialized - } - - e.mu.RLock() - defer e.mu.RUnlock() - - file, err := os.Create(e.envPath) - if err != nil { - return fmt.Errorf("failed to create .env file: %w", err) - } - defer file.Close() - - // Sort keys for consistent output - keys := make([]string, 0, len(e.vars)) - for key := range e.vars { - keys = append(keys, key) - } - sort.Strings(keys) - - // Write header comment - fmt.Fprintln(file, "# env variables - generated automatically - you can edit this file") - fmt.Fprintln(file) - - // Write each variable - for _, key := range keys { - value := e.vars[key] - - // Convert value to string - var strValue string - switch v := value.(type) { - case string: - strValue = v - case bool: - strValue = strconv.FormatBool(v) - case int64: - strValue = strconv.FormatInt(v, 10) - case float64: - strValue = strconv.FormatFloat(v, 'g', -1, 64) - case nil: - continue // Skip nil values - default: - strValue = fmt.Sprintf("%v", v) - } - - // Quote values that contain spaces or special characters - if strings.ContainsAny(strValue, " \t\n\r\"'\\") { - strValue = fmt.Sprintf("\"%s\"", strings.ReplaceAll(strValue, "\"", "\\\"")) - } - - fmt.Fprintf(file, "%s=%s\n", key, strValue) - } - - logger.Debugf("Environment saved: %d vars to %s", len(e.vars), e.envPath) - return nil -} - -// Get retrieves an environment variable -func (e *EnvManager) Get(key string) (any, bool) { - if e == nil { - return nil, false - } - - e.mu.RLock() - defer e.mu.RUnlock() - - value, exists := e.vars[key] - return value, exists -} - -// Set stores an environment variable -func (e *EnvManager) Set(key string, value any) { - if e == nil { - return - } - - e.mu.Lock() - defer e.mu.Unlock() - - e.vars[key] = value -} - -// GetAll returns a copy of all environment variables -func (e *EnvManager) GetAll() map[string]any { - if e == nil { - return make(map[string]any) - } - - e.mu.RLock() - defer e.mu.RUnlock() - - result := make(map[string]any, len(e.vars)) - for k, v := range e.vars { - result[k] = v - } - return result -} - -// CleanupEnv saves the environment and cleans up resources -func CleanupEnv() error { - if globalEnvManager != nil { - return globalEnvManager.Save() - } - return nil -} - -// envGet Lua function to get an environment variable -func envGet(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - state.PushNil() - return 1 - } - - key, err := state.SafeToString(1) - if err != nil { - state.PushNil() - return 1 - } - - if value, exists := globalEnvManager.Get(key); exists { - if err := state.PushValue(value); err != nil { - state.PushNil() - } - } else { - state.PushNil() - } - return 1 -} - -// envSet Lua function to set an environment variable -func envSet(state *luajit.State) int { - if err := state.CheckExactArgs(2); err != nil { - state.PushBoolean(false) - return 1 - } - - key, err := state.SafeToString(1) - if err != nil { - state.PushBoolean(false) - return 1 - } - - // Handle different value types from Lua - var value any - switch state.GetType(2) { - case luajit.TypeBoolean: - value = state.ToBoolean(2) - case luajit.TypeNumber: - value = state.ToNumber(2) - case luajit.TypeString: - value = state.ToString(2) - default: - // Try to convert to string as fallback - if str, err := state.SafeToString(2); err == nil { - value = str - } else { - state.PushBoolean(false) - return 1 - } - } - - globalEnvManager.Set(key, value) - state.PushBoolean(true) - return 1 -} - -// envGetAll Lua function to get all environment variables -func envGetAll(state *luajit.State) int { - vars := globalEnvManager.GetAll() - if err := state.PushValue(vars); err != nil { - state.PushNil() - } - return 1 -} - -// RegisterEnvFunctions registers environment functions with the Lua state -func RegisterEnvFunctions(state *luajit.State) error { - if err := state.RegisterGoFunction("__env_get", envGet); err != nil { - return err - } - if err := state.RegisterGoFunction("__env_set", envSet); err != nil { - return err - } - if err := state.RegisterGoFunction("__env_get_all", envGetAll); err != nil { - return err - } - return nil -} diff --git a/runner/lualibs/fs.go b/runner/lualibs/fs.go deleted file mode 100644 index 5d6d6e3..0000000 --- a/runner/lualibs/fs.go +++ /dev/null @@ -1,572 +0,0 @@ -package lualibs - -import ( - "errors" - "fmt" - "os" - "path/filepath" - "strings" - "time" - - "Moonshark/logger" - - "git.sharkk.net/Go/Color" - - lru "git.sharkk.net/Go/LRU" - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "github.com/golang/snappy" -) - -// Global filesystem path (set during initialization) -var fsBasePath string - -// Global file cache with compressed data -var fileCache *lru.LRUCache - -// Cache entry info for statistics/debugging -type cacheStats struct { - hits int64 - misses int64 -} - -var stats cacheStats - -// InitFS initializes the filesystem with the given base path -func InitFS(basePath string) error { - if basePath == "" { - return errors.New("filesystem base path cannot be empty") - } - - // Create the directory if it doesn't exist - if err := os.MkdirAll(basePath, 0755); err != nil { - return fmt.Errorf("failed to create filesystem directory: %w", err) - } - - // Store the absolute path - absPath, err := filepath.Abs(basePath) - if err != nil { - return fmt.Errorf("failed to get absolute path: %w", err) - } - - fsBasePath = absPath - - // Initialize file cache with 2000 entries (reasonable for most use cases) - fileCache = lru.NewLRUCache(2000) - - logger.Infof("Filesystem is g2g! %s", color.Yellow(fsBasePath)) - return nil -} - -// CleanupFS performs any necessary cleanup -func CleanupFS() { - if fileCache != nil { - fileCache.Clear() - logger.Infof( - "File cache cleared - %s hits, %s misses", - color.Yellow(fmt.Sprintf("%d", stats.hits)), - color.Red(fmt.Sprintf("%d", stats.misses)), - ) - } -} - -// ResolvePath resolves a given path relative to the filesystem base -// Returns the actual path and an error if the path tries to escape the sandbox -func ResolvePath(path string) (string, error) { - if fsBasePath == "" { - return "", errors.New("filesystem not initialized") - } - - // Clean the path to remove any .. or . components - cleanPath := filepath.Clean(path) - - // Replace backslashes with forward slashes for consistent handling - cleanPath = strings.ReplaceAll(cleanPath, "\\", "/") - - // Remove any leading / or drive letter to make it relative - cleanPath = strings.TrimPrefix(cleanPath, "/") - - // Remove drive letter on Windows (e.g. C:) - if len(cleanPath) >= 2 && cleanPath[1] == ':' { - cleanPath = cleanPath[2:] - } - - // Ensure the path doesn't contain .. to prevent escaping - if strings.Contains(cleanPath, "..") { - return "", errors.New("path cannot contain .. components") - } - - // Join with the base path - fullPath := filepath.Join(fsBasePath, cleanPath) - - // Verify the path is still within the base directory - if !strings.HasPrefix(fullPath, fsBasePath) { - return "", errors.New("path escapes the filesystem sandbox") - } - - return fullPath, nil -} - -// getCacheKey creates a cache key from path and modification time -func getCacheKey(fullPath string, modTime time.Time) string { - return fmt.Sprintf("%s:%d", fullPath, modTime.Unix()) -} - -// fsReadFile reads a file and returns its contents -func fsReadFile(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("fs.read_file: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.read_file: path must be string") - } - - fullPath, err := ResolvePath(path) - if err != nil { - return state.PushError("fs.read_file: %v", err) - } - - // Get file info for cache key and validation - info, err := os.Stat(fullPath) - if err != nil { - return state.PushError("fs.read_file: %v", err) - } - - // Create cache key with path and modification time - cacheKey := getCacheKey(fullPath, info.ModTime()) - - // Try to get from cache first - if fileCache != nil { - if cachedData, exists := fileCache.Get(cacheKey); exists { - if compressedData, ok := cachedData.([]byte); ok { - // Decompress cached data - data, err := snappy.Decode(nil, compressedData) - if err == nil { - stats.hits++ - state.PushString(string(data)) - return 1 - } - // Cache corruption - continue to disk read - } - } - } - - // Cache miss or error - read from disk - stats.misses++ - data, err := os.ReadFile(fullPath) - if err != nil { - return state.PushError("fs.read_file: %v", err) - } - - // Compress and cache the data - if fileCache != nil { - compressedData := snappy.Encode(nil, data) - fileCache.Put(cacheKey, compressedData) - } - - state.PushString(string(data)) - return 1 -} - -// fsWriteFile writes data to a file -func fsWriteFile(state *luajit.State) int { - if err := state.CheckExactArgs(2); err != nil { - return state.PushError("fs.write_file: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.write_file: path must be string") - } - - content, err := state.SafeToString(2) - if err != nil { - return state.PushError("fs.write_file: content must be string") - } - - fullPath, err := ResolvePath(path) - if err != nil { - return state.PushError("fs.write_file: %v", err) - } - - // Ensure the directory exists - dir := filepath.Dir(fullPath) - if err := os.MkdirAll(dir, 0755); err != nil { - return state.PushError("fs.write_file: failed to create directory: %v", err) - } - - if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { - return state.PushError("fs.write_file: %v", err) - } - - state.PushBoolean(true) - return 1 -} - -// fsAppendFile appends data to a file -func fsAppendFile(state *luajit.State) int { - if err := state.CheckExactArgs(2); err != nil { - return state.PushError("fs.append_file: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.append_file: path must be string") - } - - content, err := state.SafeToString(2) - if err != nil { - return state.PushError("fs.append_file: content must be string") - } - - fullPath, err := ResolvePath(path) - if err != nil { - return state.PushError("fs.append_file: %v", err) - } - - // Ensure the directory exists - dir := filepath.Dir(fullPath) - if err := os.MkdirAll(dir, 0755); err != nil { - return state.PushError("fs.append_file: failed to create directory: %v", err) - } - - file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return state.PushError("fs.append_file: %v", err) - } - defer file.Close() - - if _, err = file.Write([]byte(content)); err != nil { - return state.PushError("fs.append_file: %v", err) - } - - state.PushBoolean(true) - return 1 -} - -// fsExists checks if a file or directory exists -func fsExists(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("fs.exists: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.exists: path must be string") - } - - fullPath, err := ResolvePath(path) - if err != nil { - return state.PushError("fs.exists: %v", err) - } - - _, err = os.Stat(fullPath) - state.PushBoolean(err == nil) - return 1 -} - -// fsRemoveFile removes a file -func fsRemoveFile(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("fs.remove_file: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.remove_file: path must be string") - } - - fullPath, err := ResolvePath(path) - if err != nil { - return state.PushError("fs.remove_file: %v", err) - } - - // Check if it's a directory - info, err := os.Stat(fullPath) - if err != nil { - return state.PushError("fs.remove_file: %v", err) - } - - if info.IsDir() { - return state.PushError("fs.remove_file: cannot remove directory, use remove_dir instead") - } - - if err := os.Remove(fullPath); err != nil { - return state.PushError("fs.remove_file: %v", err) - } - - state.PushBoolean(true) - return 1 -} - -// fsGetInfo gets information about a file -func fsGetInfo(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("fs.get_info: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.get_info: path must be string") - } - - fullPath, err := ResolvePath(path) - if err != nil { - return state.PushError("fs.get_info: %v", err) - } - - info, err := os.Stat(fullPath) - if err != nil { - return state.PushError("fs.get_info: %v", err) - } - - fileInfo := map[string]any{ - "name": info.Name(), - "size": info.Size(), - "mode": int(info.Mode()), - "mod_time": info.ModTime().Unix(), - "is_dir": info.IsDir(), - } - - if err := state.PushValue(fileInfo); err != nil { - return state.PushError("fs.get_info: %v", err) - } - - return 1 -} - -// fsMakeDir creates a directory -func fsMakeDir(state *luajit.State) int { - if err := state.CheckMinArgs(1); err != nil { - return state.PushError("fs.make_dir: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.make_dir: path must be string") - } - - perm := os.FileMode(0755) - if state.GetTop() >= 2 { - if permVal, err := state.SafeToNumber(2); err == nil { - perm = os.FileMode(permVal) - } - } - - fullPath, err := ResolvePath(path) - if err != nil { - return state.PushError("fs.make_dir: %v", err) - } - - if err := os.MkdirAll(fullPath, perm); err != nil { - return state.PushError("fs.make_dir: %v", err) - } - - state.PushBoolean(true) - return 1 -} - -// fsListDir lists the contents of a directory -func fsListDir(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("fs.list_dir: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.list_dir: path must be string") - } - - fullPath, err := ResolvePath(path) - if err != nil { - return state.PushError("fs.list_dir: %v", err) - } - - info, err := os.Stat(fullPath) - if err != nil { - return state.PushError("fs.list_dir: %v", err) - } - - if !info.IsDir() { - return state.PushError("fs.list_dir: not a directory") - } - - files, err := os.ReadDir(fullPath) - if err != nil { - return state.PushError("fs.list_dir: %v", err) - } - - // Create array of filenames - filenames := make([]string, len(files)) - for i, file := range files { - filenames[i] = file.Name() - } - - if err := state.PushValue(filenames); err != nil { - return state.PushError("fs.list_dir: %v", err) - } - - return 1 -} - -// fsRemoveDir removes a directory -func fsRemoveDir(state *luajit.State) int { - if err := state.CheckMinArgs(1); err != nil { - return state.PushError("fs.remove_dir: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.remove_dir: path must be string") - } - - recursive := false - if state.GetTop() >= 2 { - recursive = state.ToBoolean(2) - } - - fullPath, err := ResolvePath(path) - if err != nil { - return state.PushError("fs.remove_dir: %v", err) - } - - info, err := os.Stat(fullPath) - if err != nil { - return state.PushError("fs.remove_dir: %v", err) - } - - if !info.IsDir() { - return state.PushError("fs.remove_dir: not a directory") - } - - if recursive { - err = os.RemoveAll(fullPath) - } else { - err = os.Remove(fullPath) - } - - if err != nil { - return state.PushError("fs.remove_dir: %v", err) - } - - state.PushBoolean(true) - return 1 -} - -// fsJoinPaths joins path components -func fsJoinPaths(state *luajit.State) int { - if err := state.CheckMinArgs(1); err != nil { - return state.PushError("fs.join_paths: %v", err) - } - - components := make([]string, state.GetTop()) - for i := 1; i <= state.GetTop(); i++ { - comp, err := state.SafeToString(i) - if err != nil { - return state.PushError("fs.join_paths: all arguments must be strings") - } - components[i-1] = comp - } - - result := filepath.Join(components...) - result = strings.ReplaceAll(result, "\\", "/") - - state.PushString(result) - return 1 -} - -// fsDirName returns the directory portion of a path -func fsDirName(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("fs.dir_name: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.dir_name: path must be string") - } - - dir := filepath.Dir(path) - dir = strings.ReplaceAll(dir, "\\", "/") - - state.PushString(dir) - return 1 -} - -// fsBaseName returns the file name portion of a path -func fsBaseName(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("fs.base_name: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.base_name: path must be string") - } - - base := filepath.Base(path) - state.PushString(base) - return 1 -} - -// fsExtension returns the file extension -func fsExtension(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("fs.extension: %v", err) - } - - path, err := state.SafeToString(1) - if err != nil { - return state.PushError("fs.extension: path must be string") - } - - ext := filepath.Ext(path) - state.PushString(ext) - return 1 -} - -// RegisterFSFunctions registers filesystem functions with the Lua state -func RegisterFSFunctions(state *luajit.State) error { - if err := state.RegisterGoFunction("__fs_read_file", fsReadFile); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_write_file", fsWriteFile); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_append_file", fsAppendFile); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_exists", fsExists); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_remove_file", fsRemoveFile); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_get_info", fsGetInfo); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_make_dir", fsMakeDir); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_list_dir", fsListDir); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_remove_dir", fsRemoveDir); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_join_paths", fsJoinPaths); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_dir_name", fsDirName); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_base_name", fsBaseName); err != nil { - return err - } - if err := state.RegisterGoFunction("__fs_extension", fsExtension); err != nil { - return err - } - - return nil -} diff --git a/runner/lualibs/http.go b/runner/lualibs/http.go deleted file mode 100644 index 88d6083..0000000 --- a/runner/lualibs/http.go +++ /dev/null @@ -1,227 +0,0 @@ -package lualibs - -import ( - "context" - "errors" - "fmt" - "net/url" - "strings" - "time" - - "github.com/goccy/go-json" - "github.com/valyala/bytebufferpool" - "github.com/valyala/fasthttp" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// Default HTTP client with sensible timeout -var defaultFastClient = fasthttp.Client{ - MaxConnsPerHost: 1024, - MaxIdleConnDuration: time.Minute, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - DisableHeaderNamesNormalizing: true, -} - -// HTTPClientConfig contains client settings -type HTTPClientConfig struct { - MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit) - DefaultTimeout time.Duration // Default request timeout - MaxResponseSize int64 // Maximum response size in bytes (0 = no limit) - AllowRemote bool // Whether to allow remote connections -} - -// DefaultHTTPClientConfig provides sensible defaults -var DefaultHTTPClientConfig = HTTPClientConfig{ - MaxTimeout: 60 * time.Second, - DefaultTimeout: 30 * time.Second, - MaxResponseSize: 10 * 1024 * 1024, // 10MB - AllowRemote: true, -} - -// RegisterHttpFunctions registers HTTP functions with the Lua state -func RegisterHttpFunctions(state *luajit.State) error { - if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { - return err - } - - return nil -} - -// httpRequest makes an HTTP request and returns the result to Lua -func httpRequest(state *luajit.State) int { - if err := state.CheckMinArgs(2); err != nil { - return state.PushError("http.client.request: %v", err) - } - - // Get method and URL - method, err := state.SafeToString(1) - if err != nil { - return state.PushError("http.client.request: method must be string") - } - method = strings.ToUpper(method) - - urlStr, err := state.SafeToString(2) - if err != nil { - return state.PushError("http.client.request: url must be string") - } - - // Parse URL to check if it's valid - parsedURL, err := url.Parse(urlStr) - if err != nil { - return state.PushError("Invalid URL: %v", err) - } - - // Get client configuration - config := DefaultHTTPClientConfig - - // Check if remote connections are allowed - if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") { - return state.PushError("Remote connections are not allowed") - } - - // Use bytebufferpool for request and response - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set up request - req.Header.SetMethod(method) - req.SetRequestURI(urlStr) - req.Header.Set("User-Agent", "Moonshark/1.0") - - // Get body (optional) - if state.GetTop() >= 3 && !state.IsNil(3) { - if state.IsString(3) { - // String body - bodyStr, _ := state.SafeToString(3) - req.SetBodyString(bodyStr) - } else if state.IsTable(3) { - // Table body - convert to JSON - luaTable, err := state.ToTable(3) - if err != nil { - return state.PushError("Failed to parse body table: %v", err) - } - - // Use bytebufferpool for JSON serialization - buf := bytebufferpool.Get() - defer bytebufferpool.Put(buf) - - if err := json.NewEncoder(buf).Encode(luaTable); err != nil { - return state.PushError("Failed to convert body to JSON: %v", err) - } - - req.SetBody(buf.Bytes()) - req.Header.SetContentType("application/json") - } else { - return state.PushError("Body must be a string or table") - } - } - - // Process options (headers, timeout, etc.) - timeout := config.DefaultTimeout - if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) { - if headers, ok := state.GetFieldTable(4, "headers"); ok { - if headerMap, ok := headers.(map[string]any); ok { - for name, value := range headerMap { - if valueStr, ok := value.(string); ok { - req.Header.Set(name, valueStr) - } - } - } - } - - // Get timeout - if timeoutVal := state.GetFieldNumber(4, "timeout", 0); timeoutVal > 0 { - requestTimeout := time.Duration(timeoutVal) * time.Second - - // Apply max timeout if configured - if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout { - timeout = config.MaxTimeout - } else { - timeout = requestTimeout - } - } - - // Process query parameters - if query, ok := state.GetFieldTable(4, "query"); ok { - args := req.URI().QueryArgs() - - if queryMap, ok := query.(map[string]any); ok { - for name, value := range queryMap { - switch v := value.(type) { - case string: - args.Add(name, v) - case int: - args.Add(name, fmt.Sprintf("%d", v)) - case float64: - args.Add(name, strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", v), "0"), ".")) - case bool: - if v { - args.Add(name, "true") - } else { - args.Add(name, "false") - } - } - } - } - } - } - - // Create context with timeout - _, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - // Execute request - err = defaultFastClient.DoTimeout(req, resp, timeout) - if err != nil { - errStr := "Request failed: " + err.Error() - if errors.Is(err, fasthttp.ErrTimeout) { - errStr = "Request timed out after " + timeout.String() - } - return state.PushError("%s", errStr) - } - - // Create response using TableBuilder - builder := state.NewTableBuilder() - - // Set status code and text - builder.SetNumber("status", float64(resp.StatusCode())) - builder.SetString("status_text", fasthttp.StatusMessage(resp.StatusCode())) - - // Set body - var respBody []byte - if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize { - // Make a limited copy - respBody = make([]byte, config.MaxResponseSize) - copy(respBody, resp.Body()) - } else { - respBody = resp.Body() - } - - builder.SetString("body", string(respBody)) - - // Parse body as JSON if content type is application/json - contentType := string(resp.Header.ContentType()) - if strings.Contains(contentType, "application/json") { - var jsonData any - if err := json.Unmarshal(respBody, &jsonData); err == nil { - builder.SetTable("json", jsonData) - } - } - - // Set headers - headers := make(map[string]string) - resp.Header.VisitAll(func(key, value []byte) { - headers[string(key)] = string(value) - }) - builder.SetTable("headers", headers) - - // Create ok field (true if status code is 2xx) - builder.SetBool("ok", resp.StatusCode() >= 200 && resp.StatusCode() < 300) - - builder.Build() - return 1 -} diff --git a/runner/lualibs/password.go b/runner/lualibs/password.go deleted file mode 100644 index aa53c84..0000000 --- a/runner/lualibs/password.go +++ /dev/null @@ -1,97 +0,0 @@ -package lualibs - -import ( - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "github.com/alexedwards/argon2id" -) - -// RegisterPasswordFunctions registers password-related functions in the Lua state -func RegisterPasswordFunctions(state *luajit.State) error { - if err := state.RegisterGoFunction("__password_hash", passwordHash); err != nil { - return err - } - if err := state.RegisterGoFunction("__password_verify", passwordVerify); err != nil { - return err - } - return nil -} - -// passwordHash implements the Argon2id password hashing using alexedwards/argon2id -func passwordHash(state *luajit.State) int { - if err := state.CheckMinArgs(1); err != nil { - return state.PushError("password_hash: %v", err) - } - - password, err := state.SafeToString(1) - if err != nil { - return state.PushError("password_hash: password must be string") - } - - params := &argon2id.Params{ - Memory: 128 * 1024, - Iterations: 4, - Parallelism: 4, - SaltLength: 16, - KeyLength: 32, - } - - if state.GetTop() >= 2 && state.IsTable(2) { - // Use new field getters with validation - if memory := state.GetFieldNumber(2, "memory", 0); memory > 0 { - params.Memory = max(uint32(memory), 8*1024) - } - - if iterations := state.GetFieldNumber(2, "iterations", 0); iterations > 0 { - params.Iterations = max(uint32(iterations), 1) - } - - if parallelism := state.GetFieldNumber(2, "parallelism", 0); parallelism > 0 { - params.Parallelism = max(uint8(parallelism), 1) - } - - if saltLength := state.GetFieldNumber(2, "salt_length", 0); saltLength > 0 { - params.SaltLength = max(uint32(saltLength), 8) - } - - if keyLength := state.GetFieldNumber(2, "key_length", 0); keyLength > 0 { - params.KeyLength = max(uint32(keyLength), 16) - } - } - - hash, err := argon2id.CreateHash(password, params) - if err != nil { - return state.PushError("password_hash: %v", err) - } - - state.PushString(hash) - return 1 -} - -// passwordVerify verifies a password against a hash -func passwordVerify(state *luajit.State) int { - if err := state.CheckExactArgs(2); err != nil { - state.PushBoolean(false) - return 1 - } - - password, err := state.SafeToString(1) - if err != nil { - state.PushBoolean(false) - return 1 - } - - hash, err := state.SafeToString(2) - if err != nil { - state.PushBoolean(false) - return 1 - } - - match, err := argon2id.ComparePasswordAndHash(password, hash) - if err != nil { - state.PushBoolean(false) - return 1 - } - - state.PushBoolean(match) - return 1 -} diff --git a/runner/lualibs/util.go b/runner/lualibs/util.go deleted file mode 100644 index 83ec874..0000000 --- a/runner/lualibs/util.go +++ /dev/null @@ -1,192 +0,0 @@ -package lualibs - -import ( - "encoding/base64" - "html" - "strings" - - "github.com/goccy/go-json" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// RegisterUtilFunctions registers utility functions with the Lua state -func RegisterUtilFunctions(state *luajit.State) error { - if err := state.RegisterGoFunction("__json_marshal", jsonMarshal); err != nil { - return err - } - - if err := state.RegisterGoFunction("__json_unmarshal", jsonUnmarshal); err != nil { - return err - } - - // HTML special chars - if err := state.RegisterGoFunction("__html_special_chars", htmlSpecialChars); err != nil { - return err - } - - // HTML entities - if err := state.RegisterGoFunction("__html_entities", htmlEntities); err != nil { - return err - } - - // Base64 encode - if err := state.RegisterGoFunction("__base64_encode", base64Encode); err != nil { - return err - } - - // Base64 decode - if err := state.RegisterGoFunction("__base64_decode", base64Decode); err != nil { - return err - } - - return nil -} - -// htmlSpecialChars converts special characters to HTML entities -func htmlSpecialChars(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - state.PushNil() - return 1 - } - - input, err := state.SafeToString(1) - if err != nil { - state.PushNil() - return 1 - } - - result := html.EscapeString(input) - state.PushString(result) - return 1 -} - -// htmlEntities is a more comprehensive version of htmlSpecialChars -func htmlEntities(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - state.PushNil() - return 1 - } - - input, err := state.SafeToString(1) - if err != nil { - state.PushNil() - return 1 - } - - // First use HTML escape for standard entities - result := html.EscapeString(input) - - // Additional entities beyond what html.EscapeString handles - replacements := map[string]string{ - "©": "©", - "®": "®", - "™": "™", - "€": "€", - "£": "£", - "¥": "¥", - "—": "—", - "–": "–", - "…": "…", - "•": "•", - "°": "°", - "±": "±", - "¼": "¼", - "½": "½", - "¾": "¾", - } - - for char, entity := range replacements { - result = strings.ReplaceAll(result, char, entity) - } - - state.PushString(result) - return 1 -} - -// base64Encode encodes a string to base64 -func base64Encode(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - state.PushNil() - return 1 - } - - input, err := state.SafeToString(1) - if err != nil { - state.PushNil() - return 1 - } - - result := base64.StdEncoding.EncodeToString([]byte(input)) - state.PushString(result) - return 1 -} - -// base64Decode decodes a base64 string -func base64Decode(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - state.PushNil() - return 1 - } - - input, err := state.SafeToString(1) - if err != nil { - state.PushNil() - return 1 - } - - result, err := base64.StdEncoding.DecodeString(input) - if err != nil { - state.PushNil() - return 1 - } - - state.PushString(string(result)) - return 1 -} - -// jsonMarshal converts a Lua value to a JSON string with validation -func jsonMarshal(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("json marshal: %v", err) - } - - value, err := state.ToTable(1) - if err != nil { - // Try as generic value if not a table - value, err = state.ToValue(1) - if err != nil { - return state.PushError("json marshal error: %v", err) - } - } - - bytes, err := json.Marshal(value) - if err != nil { - return state.PushError("json marshal error: %v", err) - } - - state.PushString(string(bytes)) - return 1 -} - -// jsonUnmarshal converts a JSON string to a Lua value with validation -func jsonUnmarshal(state *luajit.State) int { - if err := state.CheckExactArgs(1); err != nil { - return state.PushError("json unmarshal: %v", err) - } - - jsonStr, err := state.SafeToString(1) - if err != nil { - return state.PushError("json unmarshal: expected string, got %s", state.GetType(1)) - } - - var value any - if err := json.Unmarshal([]byte(jsonStr), &value); err != nil { - return state.PushError("json unmarshal error: %v", err) - } - - if err := state.PushValue(value); err != nil { - return state.PushError("json unmarshal error: %v", err) - } - return 1 -} diff --git a/runner/moduleLoader.go b/runner/moduleLoader.go deleted file mode 100644 index e806612..0000000 --- a/runner/moduleLoader.go +++ /dev/null @@ -1,292 +0,0 @@ -package runner - -import ( - "fmt" - "os" - "path/filepath" - "strings" - "sync" - - "Moonshark/logger" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -type ModuleConfig struct { - ScriptDir string - LibDirs []string -} - -type ModuleLoader struct { - config *ModuleConfig - pathCache map[string]string // For reverse lookups (path -> module name) - debug bool - mu sync.RWMutex -} - -func NewModuleLoader(config *ModuleConfig) *ModuleLoader { - if config == nil { - config = &ModuleConfig{} - } - - return &ModuleLoader{ - config: config, - pathCache: make(map[string]string), - } -} - -func (l *ModuleLoader) EnableDebug() { - l.debug = true -} - -func (l *ModuleLoader) SetScriptDir(dir string) { - l.mu.Lock() - defer l.mu.Unlock() - l.config.ScriptDir = dir -} - -func (l *ModuleLoader) debugLog(format string, args ...any) { - if l.debug { - logger.Debugf("ModuleLoader "+format, args...) - } -} - -func (l *ModuleLoader) SetupRequire(state *luajit.State) error { - // Set package.path - paths := l.getSearchPaths() - pathStr := strings.Join(paths, ";") - - return state.DoString(`package.path = "` + escapeLuaString(pathStr) + `"`) -} - -func (l *ModuleLoader) getSearchPaths() []string { - var paths []string - seen := make(map[string]bool) - - // Script directory first - if l.config.ScriptDir != "" { - if absPath, err := filepath.Abs(l.config.ScriptDir); err == nil && !seen[absPath] { - paths = append(paths, filepath.Join(absPath, "?.lua")) - seen[absPath] = true - } - } - - // Library directories - for _, dir := range l.config.LibDirs { - if dir == "" { - continue - } - if absPath, err := filepath.Abs(dir); err == nil && !seen[absPath] { - paths = append(paths, filepath.Join(absPath, "?.lua")) - seen[absPath] = true - } - } - - return paths -} - -func (l *ModuleLoader) PreloadModules(state *luajit.State) error { - l.mu.Lock() - defer l.mu.Unlock() - - // Reset caches - l.pathCache = make(map[string]string) - - // Clear non-core modules - err := state.DoString(` - local core = {string=1, table=1, math=1, os=1, package=1, io=1, coroutine=1, debug=1, _G=1} - for name in pairs(package.loaded) do - if not core[name] then package.loaded[name] = nil end - end - package.preload = {} - `) - if err != nil { - return err - } - - // Scan and preload modules - for _, dir := range l.config.LibDirs { - if err := l.scanDirectory(state, dir); err != nil { - return err - } - } - - // Install simplified require - return state.DoString(` - function __setup_require(env) - env.require = function(modname) - if package.loaded[modname] then - return package.loaded[modname] - end - - local loader = package.preload[modname] - if loader then - setfenv(loader, env) - local result = loader() or true - package.loaded[modname] = result - return result - end - - -- Standard path search - for path in package.path:gmatch("[^;]+") do - local file = path:gsub("?", modname:gsub("%.", "/")) - local chunk = loadfile(file) - if chunk then - setfenv(chunk, env) - local result = chunk() or true - package.loaded[modname] = result - return result - end - end - - error("module '" .. modname .. "' not found", 2) - end - return env - end - `) -} - -func (l *ModuleLoader) scanDirectory(state *luajit.State, dir string) error { - if dir == "" { - return nil - } - - absDir, err := filepath.Abs(dir) - if err != nil { - return nil - } - - l.debugLog("Scanning directory: %s", absDir) - - return filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error { - if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") { - return nil - } - - relPath, err := filepath.Rel(absDir, path) - if err != nil || strings.HasPrefix(relPath, "..") { - return nil - } - - // Convert to module name - modName := strings.TrimSuffix(relPath, ".lua") - modName = strings.ReplaceAll(modName, string(filepath.Separator), ".") - - l.debugLog("Found module: %s at %s", modName, path) - l.pathCache[modName] = path - - // Load and compile module - content, err := os.ReadFile(path) - if err != nil { - l.debugLog("Failed to read %s: %v", path, err) - return nil - } - - if err := state.LoadString(string(content)); err != nil { - l.debugLog("Failed to compile %s: %v", path, err) - return nil - } - - // Store in package.preload - state.GetGlobal("package") - state.GetField(-1, "preload") - state.PushString(modName) - state.PushCopy(-4) // Copy compiled function - state.SetTable(-3) - state.Pop(2) // Pop package and preload - state.Pop(1) // Pop function - - return nil - }) -} - -func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) { - l.mu.RLock() - defer l.mu.RUnlock() - - absPath, err := filepath.Abs(path) - if err != nil { - absPath = filepath.Clean(path) - } - - // Direct lookup - for modName, modPath := range l.pathCache { - if modPath == absPath { - return modName, true - } - } - - // Construct from lib dirs - for _, dir := range l.config.LibDirs { - absDir, err := filepath.Abs(dir) - if err != nil { - continue - } - - relPath, err := filepath.Rel(absDir, absPath) - if err != nil || strings.HasPrefix(relPath, "..") || !strings.HasSuffix(relPath, ".lua") { - continue - } - - modName := strings.TrimSuffix(relPath, ".lua") - modName = strings.ReplaceAll(modName, string(filepath.Separator), ".") - return modName, true - } - - return "", false -} - -func (l *ModuleLoader) RefreshModule(state *luajit.State, moduleName string) error { - l.mu.Lock() - defer l.mu.Unlock() - - path, exists := l.pathCache[moduleName] - if !exists { - return fmt.Errorf("module %s not found", moduleName) - } - - l.debugLog("Refreshing module: %s", moduleName) - - content, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("failed to read module: %w", err) - } - - // Compile new version - if err := state.LoadString(string(content)); err != nil { - return fmt.Errorf("failed to compile module: %w", err) - } - - // Update package.preload - state.GetGlobal("package") - state.GetField(-1, "preload") - state.PushString(moduleName) - state.PushCopy(-4) // Copy function - state.SetTable(-3) - state.Pop(2) // Pop package and preload - state.Pop(1) // Pop function - - // Clear from loaded - state.DoString(`package.loaded["` + escapeLuaString(moduleName) + `"] = nil`) - - l.debugLog("Successfully refreshed: %s", moduleName) - return nil -} - -func (l *ModuleLoader) RefreshModuleByPath(state *luajit.State, filePath string) error { - moduleName, exists := l.GetModuleByPath(filePath) - if !exists { - return fmt.Errorf("no module found for path: %s", filePath) - } - return l.RefreshModule(state, moduleName) -} - -func escapeLuaString(s string) string { - return strings.NewReplacer( - `\`, `\\`, - `"`, `\"`, - "\n", `\n`, - "\r", `\r`, - "\t", `\t`, - ).Replace(s) -} diff --git a/runner/response.go b/runner/response.go deleted file mode 100644 index 751e554..0000000 --- a/runner/response.go +++ /dev/null @@ -1,56 +0,0 @@ -package runner - -import ( - "sync" - - "github.com/valyala/fasthttp" -) - -// Response represents a unified response from script execution -type Response struct { - // Basic properties - Body any // Body content (any type) - Metadata map[string]any // Additional metadata - - // HTTP specific properties - Status int // HTTP status code - Headers map[string]string // HTTP headers - Cookies []*fasthttp.Cookie // HTTP cookies - - // Session information - SessionData map[string]any -} - -// Response pool to reduce allocations -var responsePool = sync.Pool{ - New: func() any { - return &Response{ - Status: 200, - Headers: make(map[string]string, 8), - Metadata: make(map[string]any, 8), - Cookies: make([]*fasthttp.Cookie, 0, 4), - SessionData: make(map[string]any, 8), - } - }, -} - -// NewResponse creates a new response object from the pool -func NewResponse() *Response { - return responsePool.Get().(*Response) -} - -// Release returns a response to the pool after cleaning it -func ReleaseResponse(resp *Response) { - if resp == nil { - return - } - - resp.Body = nil - resp.Status = 200 - resp.Headers = make(map[string]string, 8) - resp.Metadata = make(map[string]any, 8) - resp.Cookies = resp.Cookies[:0] - resp.SessionData = make(map[string]any, 8) - - responsePool.Put(resp) -} diff --git a/runner/runner.go b/runner/runner.go deleted file mode 100644 index 0aafa2d..0000000 --- a/runner/runner.go +++ /dev/null @@ -1,335 +0,0 @@ -package runner - -import ( - "errors" - "fmt" - "os" - "path/filepath" - "runtime" - "sync" - "sync/atomic" - "time" - - "Moonshark/config" - "Moonshark/logger" - "Moonshark/runner/lualibs" - "Moonshark/runner/sqlite" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -var emptyMap = make(map[string]any) - -var ( - ErrRunnerClosed = errors.New("lua runner is closed") - ErrTimeout = errors.New("operation timed out") - ErrStateNotReady = errors.New("lua state not ready") -) - -type State struct { - L *luajit.State - sandbox *Sandbox - index int - inUse atomic.Bool -} - -type Runner struct { - states []*State - statePool chan int - poolSize int - moduleLoader *ModuleLoader - isRunning atomic.Bool - mu sync.RWMutex - scriptDir string - - // Pre-allocated pools for HTTP processing - ctxPool sync.Pool - paramsPool sync.Pool -} - -func NewRunner(cfg *config.Config, poolSize int) (*Runner, error) { - if poolSize <= 0 && cfg.Runner.PoolSize <= 0 { - poolSize = runtime.GOMAXPROCS(0) - } - - moduleConfig := &ModuleConfig{ - LibDirs: cfg.Dirs.Libs, - } - - r := &Runner{ - poolSize: poolSize, - moduleLoader: NewModuleLoader(moduleConfig), - ctxPool: sync.Pool{ - New: func() any { return make(map[string]any, 8) }, - }, - paramsPool: sync.Pool{ - New: func() any { return make(map[string]any, 4) }, - }, - } - - sqlite.InitSQLite(cfg.Dirs.Data) - sqlite.SetSQLitePoolSize(poolSize) - lualibs.InitFS(cfg.Dirs.FS) - lualibs.InitEnv(cfg.Dirs.Data) - - r.states = make([]*State, poolSize) - r.statePool = make(chan int, poolSize) - - if err := r.initStates(); err != nil { - sqlite.CleanupSQLite() - return nil, err - } - - r.isRunning.Store(true) - return r, nil -} - -func (r *Runner) Execute(bytecode []byte, ctx ExecutionContext) (*Response, error) { - if !r.isRunning.Load() { - return nil, ErrRunnerClosed - } - - var stateIndex int - select { - case stateIndex = <-r.statePool: - case <-time.After(time.Second): - return nil, ErrTimeout - } - - state := r.states[stateIndex] - if state == nil { - r.statePool <- stateIndex - return nil, ErrStateNotReady - } - - state.inUse.Store(true) - defer func() { - state.inUse.Store(false) - if r.isRunning.Load() { - select { - case r.statePool <- stateIndex: - default: - } - } - }() - - return state.sandbox.Execute(state.L, bytecode, ctx, state.index) -} - -func (r *Runner) initStates() error { - logger.Infof("[LuaRunner] Creating %d states...", r.poolSize) - - for i := range r.poolSize { - state, err := r.createState(i) - if err != nil { - return err - } - r.states[i] = state - r.statePool <- i - } - return nil -} - -func (r *Runner) createState(index int) (*State, error) { - L := luajit.New(true) - if L == nil { - return nil, errors.New("failed to create Lua state") - } - - sb := NewSandbox() - if err := sb.Setup(L, index, index == 0); err != nil { - L.Cleanup() - L.Close() - return nil, err - } - - if err := r.moduleLoader.SetupRequire(L); err != nil { - L.Cleanup() - L.Close() - return nil, err - } - - if err := r.moduleLoader.PreloadModules(L); err != nil { - L.Cleanup() - L.Close() - return nil, err - } - - return &State{L: L, sandbox: sb, index: index}, nil -} - -func (r *Runner) Close() error { - r.mu.Lock() - defer r.mu.Unlock() - - if !r.isRunning.Load() { - return ErrRunnerClosed - } - r.isRunning.Store(false) - - // Drain pool - for { - select { - case <-r.statePool: - default: - goto cleanup - } - } - -cleanup: - // Wait for states to finish - timeout := time.Now().Add(10 * time.Second) - for time.Now().Before(timeout) { - allIdle := true - for _, state := range r.states { - if state != nil && state.inUse.Load() { - allIdle = false - break - } - } - if allIdle { - break - } - time.Sleep(10 * time.Millisecond) - } - - // Close states - for i, state := range r.states { - if state != nil { - state.L.Cleanup() - state.L.Close() - r.states[i] = nil - } - } - - lualibs.CleanupFS() - sqlite.CleanupSQLite() - lualibs.CleanupEnv() - return nil -} - -// NotifyFileChanged alerts the runner about file changes -func (r *Runner) NotifyFileChanged(filePath string) bool { - logger.Debugf("Runner notified of file change: %s", filePath) - - module, isModule := r.moduleLoader.GetModuleByPath(filePath) - if isModule { - logger.Debugf("Refreshing module: %s", module) - return r.RefreshModule(module) - } - - logger.Debugf("File change noted but no refresh needed: %s", filePath) - return true -} - -// RefreshModule refreshes a specific module across all states -func (r *Runner) RefreshModule(moduleName string) bool { - r.mu.RLock() - defer r.mu.RUnlock() - - if !r.isRunning.Load() { - return false - } - - logger.Debugf("Refreshing module: %s", moduleName) - - success := true - for _, state := range r.states { - if state == nil || state.inUse.Load() { - continue - } - - if err := r.moduleLoader.RefreshModule(state.L, moduleName); err != nil { - success = false - logger.Debugf("Failed to refresh module %s in state %d: %v", moduleName, state.index, err) - } - } - - if success { - logger.Debugf("Successfully refreshed module: %s", moduleName) - } - - return success -} - -// RunScriptFile loads, compiles and executes a Lua script file -func (r *Runner) RunScriptFile(filePath string) (*Response, error) { - if !r.isRunning.Load() { - return nil, ErrRunnerClosed - } - - if _, err := os.Stat(filePath); os.IsNotExist(err) { - return nil, fmt.Errorf("script file not found: %s", filePath) - } - - content, err := os.ReadFile(filePath) - if err != nil { - return nil, fmt.Errorf("failed to read file: %w", err) - } - - absPath, err := filepath.Abs(filePath) - if err != nil { - return nil, fmt.Errorf("failed to get absolute path: %w", err) - } - scriptDir := filepath.Dir(absPath) - - r.mu.Lock() - prevScriptDir := r.scriptDir - r.scriptDir = scriptDir - r.moduleLoader.SetScriptDir(scriptDir) - r.mu.Unlock() - - defer func() { - r.mu.Lock() - r.scriptDir = prevScriptDir - r.moduleLoader.SetScriptDir(prevScriptDir) - r.mu.Unlock() - }() - - // Get state from pool - var stateIndex int - select { - case stateIndex = <-r.statePool: - case <-time.After(5 * time.Second): - return nil, ErrTimeout - } - - state := r.states[stateIndex] - if state == nil { - r.statePool <- stateIndex - return nil, ErrStateNotReady - } - - state.inUse.Store(true) - - defer func() { - state.inUse.Store(false) - if r.isRunning.Load() { - select { - case r.statePool <- stateIndex: - default: - } - } - }() - - // Compile script - bytecode, err := state.L.CompileBytecode(string(content), filepath.Base(absPath)) - if err != nil { - return nil, fmt.Errorf("compilation error: %w", err) - } - - // Create simple context for script execution - ctx := NewContext() - defer ctx.Release() - - ctx.Set("_script_path", absPath) - ctx.Set("_script_dir", scriptDir) - - // Execute script - response, err := state.sandbox.Execute(state.L, bytecode, ctx, state.index) - if err != nil { - return nil, fmt.Errorf("execution error: %w", err) - } - - return response, nil -} diff --git a/runner/sandbox.go b/runner/sandbox.go deleted file mode 100644 index 5393390..0000000 --- a/runner/sandbox.go +++ /dev/null @@ -1,220 +0,0 @@ -package runner - -import ( - "Moonshark/runner/lualibs" - "Moonshark/runner/sqlite" - "fmt" - - "maps" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" - "github.com/valyala/fasthttp" -) - -// Sandbox provides a secure execution environment for Lua scripts -type Sandbox struct { - executorBytecode []byte -} - -// NewSandbox creates a new sandbox environment -func NewSandbox() *Sandbox { - return &Sandbox{} -} - -// Setup initializes the sandbox in a Lua state -func (s *Sandbox) Setup(state *luajit.State, stateIndex int, verbose bool) error { - // Load all embedded modules and sandbox - if err := loadSandboxIntoState(state, verbose); err != nil { - return fmt.Errorf("failed to load sandbox: %w", err) - } - - // Set the state index as a global variable - state.PushNumber(float64(stateIndex)) - state.SetGlobal("__STATE_INDEX") - - // Pre-compile the executor function for reuse - executorCode := `return __execute` - bytecode, err := state.CompileBytecode(executorCode, "executor") - if err != nil { - return fmt.Errorf("failed to compile executor: %w", err) - } - s.executorBytecode = bytecode - - // Register native functions - if err := s.registerCoreFunctions(state); err != nil { - return err - } - - return nil -} - -// Execute runs a Lua script in the sandbox with the given context -func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx ExecutionContext, stateIndex int) (*Response, error) { - // Load script and executor - if err := state.LoadBytecode(bytecode, "script"); err != nil { - return nil, fmt.Errorf("failed to load bytecode: %w", err) - } - - if err := state.LoadBytecode(s.executorBytecode, "executor"); err != nil { - state.Pop(1) - return nil, fmt.Errorf("failed to load executor: %w", err) - } - - // Get __execute function - if err := state.Call(0, 1); err != nil { - state.Pop(1) - return nil, fmt.Errorf("failed to get executor: %w", err) - } - - // Prepare response object - response := map[string]any{ - "status": 200, - "headers": make(map[string]string), - "cookies": []any{}, - "metadata": make(map[string]any), - "session": make(map[string]any), - "flash": make(map[string]any), - } - - // Call __execute(script_func, ctx, response) - state.PushCopy(-2) // script function - state.PushValue(ctx.ToMap()) - state.PushValue(response) - - if err := state.Call(3, 1); err != nil { - state.Pop(1) - return nil, fmt.Errorf("script execution failed: %w", err) - } - - // Extract result - result, _ := state.ToValue(-1) - state.Pop(2) // Clean up - - sqlite.CleanupStateConnection(stateIndex) - - var modifiedResponse map[string]any - var scriptResult any - - if arr, ok := result.([]any); ok && len(arr) >= 2 { - scriptResult = arr[0] - if resp, ok := arr[1].(map[string]any); ok { - modifiedResponse = resp - } - } - - if modifiedResponse == nil { - scriptResult = result - modifiedResponse = response - } - - return s.buildResponse(modifiedResponse, scriptResult), nil -} - -// buildResponse converts the Lua response object to a Go Response -func (s *Sandbox) buildResponse(luaResp map[string]any, body any) *Response { - resp := NewResponse() - resp.Body = body - - // Extract status - if status, ok := luaResp["status"].(float64); ok { - resp.Status = int(status) - } else if status, ok := luaResp["status"].(int); ok { - resp.Status = status - } - - // Extract headers - if headers, ok := luaResp["headers"].(map[string]any); ok { - for k, v := range headers { - if str, ok := v.(string); ok { - resp.Headers[k] = str - } - } - } - - // Extract cookies - if cookies, ok := luaResp["cookies"].([]any); ok { - for _, cookieData := range cookies { - if cookieMap, ok := cookieData.(map[string]any); ok { - cookie := fasthttp.AcquireCookie() - - if name, ok := cookieMap["name"].(string); ok && name != "" { - cookie.SetKey(name) - if value, ok := cookieMap["value"].(string); ok { - cookie.SetValue(value) - } - if path, ok := cookieMap["path"].(string); ok { - cookie.SetPath(path) - } - if domain, ok := cookieMap["domain"].(string); ok { - cookie.SetDomain(domain) - } - if httpOnly, ok := cookieMap["http_only"].(bool); ok { - cookie.SetHTTPOnly(httpOnly) - } - if secure, ok := cookieMap["secure"].(bool); ok { - cookie.SetSecure(secure) - } - if maxAge, ok := cookieMap["max_age"].(float64); ok { - cookie.SetMaxAge(int(maxAge)) - } else if maxAge, ok := cookieMap["max_age"].(int); ok { - cookie.SetMaxAge(maxAge) - } - - resp.Cookies = append(resp.Cookies, cookie) - } else { - fasthttp.ReleaseCookie(cookie) - } - } - } - } - - // Extract metadata - simplified - if metadata, ok := luaResp["metadata"].(map[string]any); ok { - maps.Copy(resp.Metadata, metadata) - } - - // Extract session data - simplified - if session, ok := luaResp["session"].(map[string]any); ok { - maps.Copy(resp.SessionData, session) - } - - // Extract flash data and add to metadata for processing by server - if flash, ok := luaResp["flash"].(map[string]any); ok && len(flash) > 0 { - resp.Metadata["flash"] = flash - } - - return resp -} - -// registerCoreFunctions registers all built-in functions in the Lua state -func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { - if err := lualibs.RegisterCryptoFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterEnvFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterFSFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterHttpFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterPasswordFunctions(state); err != nil { - return err - } - - if err := sqlite.RegisterSQLiteFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterUtilFunctions(state); err != nil { - return err - } - - return nil -} diff --git a/runner/sqlite/sqlite.go b/runner/sqlite/sqlite.go deleted file mode 100644 index 4f8f46e..0000000 --- a/runner/sqlite/sqlite.go +++ /dev/null @@ -1,466 +0,0 @@ -package sqlite - -import ( - "context" - "fmt" - "path/filepath" - "strings" - "sync" - "time" - - sqlite "zombiezen.com/go/sqlite" - "zombiezen.com/go/sqlite/sqlitex" - - "Moonshark/logger" - - "git.sharkk.net/Go/Color" - - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -var ( - dbPools = make(map[string]*sqlitex.Pool) - poolsMu sync.RWMutex - dataDir string - poolSize = 8 - connTimeout = 5 * time.Second - - // Per-state connection cache - stateConns = make(map[string]*stateConn) - stateConnsMu sync.RWMutex -) - -// stateConn tracks a connection and its origin pool -type stateConn struct { - conn *sqlite.Conn - pool *sqlitex.Pool -} - -func InitSQLite(dir string) { - dataDir = dir - logger.Infof("SQLite is g2g! %s", color.Yellow(dir)) -} - -func SetSQLitePoolSize(size int) { - if size > 0 { - poolSize = size - } -} - -func CleanupSQLite() { - poolsMu.Lock() - defer poolsMu.Unlock() - - // Return all cached connections to their pools - stateConnsMu.Lock() - for _, sc := range stateConns { - if sc.pool != nil && sc.conn != nil { - sc.pool.Put(sc.conn) - } - } - stateConns = make(map[string]*stateConn) - stateConnsMu.Unlock() - - for name, pool := range dbPools { - if err := pool.Close(); err != nil { - logger.Errorf("Failed to close database %s: %v", name, err) - } - } - - dbPools = make(map[string]*sqlitex.Pool) - logger.Debugf("SQLite connections closed") -} - -func getPool(dbName string) (*sqlitex.Pool, error) { - dbName = filepath.Base(dbName) - if dbName == "" || dbName[0] == '.' { - return nil, fmt.Errorf("invalid database name") - } - - poolsMu.RLock() - pool, exists := dbPools[dbName] - if exists { - poolsMu.RUnlock() - return pool, nil - } - poolsMu.RUnlock() - - poolsMu.Lock() - defer poolsMu.Unlock() - - if pool, exists = dbPools[dbName]; exists { - return pool, nil - } - - dbPath := filepath.Join(dataDir, dbName+".db") - pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{ - PoolSize: poolSize, - PrepareConn: func(conn *sqlite.Conn) error { - pragmas := []string{ - "PRAGMA journal_mode = WAL", - "PRAGMA synchronous = NORMAL", - "PRAGMA cache_size = 1000", - "PRAGMA foreign_keys = ON", - "PRAGMA temp_store = MEMORY", - } - for _, pragma := range pragmas { - if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil { - return err - } - } - return nil - }, - }) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - - dbPools[dbName] = pool - logger.Debugf("Created SQLite pool for %s (size: %d)", dbName, poolSize) - return pool, nil -} - -// getStateConnection gets or creates a reusable connection for the state+db -func getStateConnection(stateIndex int, dbName string) (*sqlite.Conn, error) { - connKey := fmt.Sprintf("%d-%s", stateIndex, dbName) - - stateConnsMu.RLock() - sc, exists := stateConns[connKey] - stateConnsMu.RUnlock() - - if exists && sc.conn != nil { - return sc.conn, nil - } - - // Get new connection from pool - pool, err := getPool(dbName) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithTimeout(context.Background(), connTimeout) - defer cancel() - - conn, err := pool.Take(ctx) - if err != nil { - return nil, fmt.Errorf("connection timeout: %w", err) - } - - // Cache it with pool reference - stateConnsMu.Lock() - stateConns[connKey] = &stateConn{ - conn: conn, - pool: pool, - } - stateConnsMu.Unlock() - - return conn, nil -} - -func sqlQuery(state *luajit.State) int { - if err := state.CheckMinArgs(3); err != nil { - return state.PushError("sqlite.query: %v", err) - } - - dbName, err := state.SafeToString(1) - if err != nil { - return state.PushError("sqlite.query: database name must be string") - } - - query, err := state.SafeToString(2) - if err != nil { - return state.PushError("sqlite.query: query must be string") - } - - stateIndex := int(state.ToNumber(-1)) - - conn, err := getStateConnection(stateIndex, dbName) - if err != nil { - return state.PushError("sqlite.query: %v", err) - } - - var execOpts sqlitex.ExecOptions - rows := make([]any, 0, 16) - - if state.GetTop() >= 4 && !state.IsNil(3) { - if err := setupParams(state, 3, &execOpts); err != nil { - return state.PushError("sqlite.query: %v", err) - } - } - - execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { - row := make(map[string]any) - colCount := stmt.ColumnCount() - - for i := range colCount { - colName := stmt.ColumnName(i) - switch stmt.ColumnType(i) { - case sqlite.TypeInteger: - row[colName] = stmt.ColumnInt64(i) - case sqlite.TypeFloat: - row[colName] = stmt.ColumnFloat(i) - case sqlite.TypeText: - row[colName] = stmt.ColumnText(i) - case sqlite.TypeBlob: - blobSize := stmt.ColumnLen(i) - if blobSize > 0 { - buf := make([]byte, blobSize) - row[colName] = stmt.ColumnBytes(i, buf) - } else { - row[colName] = []byte{} - } - case sqlite.TypeNull: - row[colName] = nil - } - } - rows = append(rows, row) - return nil - } - - if err := sqlitex.Execute(conn, query, &execOpts); err != nil { - return state.PushError("sqlite.query: %v", err) - } - - if err := state.PushValue(rows); err != nil { - return state.PushError("sqlite.query: %v", err) - } - - return 1 -} - -func sqlExec(state *luajit.State) int { - if err := state.CheckMinArgs(3); err != nil { - return state.PushError("sqlite.exec: %v", err) - } - - dbName, err := state.SafeToString(1) - if err != nil { - return state.PushError("sqlite.exec: database name must be string") - } - - query, err := state.SafeToString(2) - if err != nil { - return state.PushError("sqlite.exec: query must be string") - } - - stateIndex := int(state.ToNumber(-1)) - - conn, err := getStateConnection(stateIndex, dbName) - if err != nil { - return state.PushError("sqlite.exec: %v", err) - } - - hasParams := state.GetTop() >= 4 && !state.IsNil(3) - - if strings.Contains(query, ";") && !hasParams { - if err := sqlitex.ExecScript(conn, query); err != nil { - return state.PushError("sqlite.exec: %v", err) - } - state.PushNumber(float64(conn.Changes())) - return 1 - } - - if !hasParams { - if err := sqlitex.Execute(conn, query, nil); err != nil { - return state.PushError("sqlite.exec: %v", err) - } - state.PushNumber(float64(conn.Changes())) - return 1 - } - - var execOpts sqlitex.ExecOptions - if err := setupParams(state, 3, &execOpts); err != nil { - return state.PushError("sqlite.exec: %v", err) - } - - if err := sqlitex.Execute(conn, query, &execOpts); err != nil { - return state.PushError("sqlite.exec: %v", err) - } - - state.PushNumber(float64(conn.Changes())) - return 1 -} - -func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error { - if state.IsTable(paramIndex) { - paramsAny, err := state.ToTable(paramIndex) - if err != nil { - return fmt.Errorf("invalid parameters: %w", err) - } - - // Handle direct array types - if arrParams, ok := paramsAny.([]any); ok { - execOpts.Args = arrParams - return nil - } - if strArr, ok := paramsAny.([]string); ok { - args := make([]any, len(strArr)) - for i, v := range strArr { - args[i] = v - } - execOpts.Args = args - return nil - } - if floatArr, ok := paramsAny.([]float64); ok { - args := make([]any, len(floatArr)) - for i, v := range floatArr { - args[i] = v - } - execOpts.Args = args - return nil - } - - params, ok := paramsAny.(map[string]any) - if !ok { - return fmt.Errorf("unsupported parameter type: %T", paramsAny) - } - - // Check for array-style parameters (empty string key indicates array) - if arr, ok := params[""]; ok { - if arrParams, ok := arr.([]any); ok { - execOpts.Args = arrParams - } else if floatArr, ok := arr.([]float64); ok { - args := make([]any, len(floatArr)) - for i, v := range floatArr { - args[i] = v - } - execOpts.Args = args - } - } else { - // Named parameters - named := make(map[string]any, len(params)) - for k, v := range params { - if len(k) > 0 && k[0] != ':' { - named[":"+k] = v - } else { - named[k] = v - } - } - execOpts.Named = named - } - } else { - // Multiple individual parameters - count := state.GetTop() - 2 - args := make([]any, count) - for i := range count { - idx := i + 3 - val, err := state.ToValue(idx) - if err != nil { - return fmt.Errorf("invalid parameter %d: %w", i+1, err) - } - args[i] = val - } - execOpts.Args = args - } - - return nil -} - -func sqlGetOne(state *luajit.State) int { - if err := state.CheckMinArgs(3); err != nil { - return state.PushError("sqlite.get_one: %v", err) - } - - dbName, err := state.SafeToString(1) - if err != nil { - return state.PushError("sqlite.get_one: database name must be string") - } - - query, err := state.SafeToString(2) - if err != nil { - return state.PushError("sqlite.get_one: query must be string") - } - - stateIndex := int(state.ToNumber(-1)) - - conn, err := getStateConnection(stateIndex, dbName) - if err != nil { - return state.PushError("sqlite.get_one: %v", err) - } - - var execOpts sqlitex.ExecOptions - var result map[string]any - - // Check if params provided (before state index) - if state.GetTop() >= 4 && !state.IsNil(3) { - if err := setupParams(state, 3, &execOpts); err != nil { - return state.PushError("sqlite.get_one: %v", err) - } - } - - execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { - if result != nil { - return nil - } - - result = make(map[string]any) - colCount := stmt.ColumnCount() - - for i := range colCount { - colName := stmt.ColumnName(i) - switch stmt.ColumnType(i) { - case sqlite.TypeInteger: - result[colName] = stmt.ColumnInt64(i) - case sqlite.TypeFloat: - result[colName] = stmt.ColumnFloat(i) - case sqlite.TypeText: - result[colName] = stmt.ColumnText(i) - case sqlite.TypeBlob: - blobSize := stmt.ColumnLen(i) - if blobSize > 0 { - buf := make([]byte, blobSize) - result[colName] = stmt.ColumnBytes(i, buf) - } else { - result[colName] = []byte{} - } - case sqlite.TypeNull: - result[colName] = nil - } - } - return nil - } - - if err := sqlitex.Execute(conn, query, &execOpts); err != nil { - return state.PushError("sqlite.get_one: %v", err) - } - - if result == nil { - state.PushNil() - } else { - if err := state.PushValue(result); err != nil { - return state.PushError("sqlite.get_one: %v", err) - } - } - - return 1 -} - -// CleanupStateConnection releases all connections for a specific state -func CleanupStateConnection(stateIndex int) { - stateConnsMu.Lock() - defer stateConnsMu.Unlock() - - statePrefix := fmt.Sprintf("%d-", stateIndex) - - for key, sc := range stateConns { - if strings.HasPrefix(key, statePrefix) { - if sc.pool != nil && sc.conn != nil { - sc.pool.Put(sc.conn) - } - delete(stateConns, key) - } - } -} - -func RegisterSQLiteFunctions(state *luajit.State) error { - if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { - return err - } - if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil { - return err - } - if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil { - return err - } - return nil -} diff --git a/test.lua b/test.lua new file mode 100644 index 0000000..4fea308 --- /dev/null +++ b/test.lua @@ -0,0 +1,189 @@ +-- Example HTTP server with string-based routing, parameters, and wildcards + +print("Starting Moonshark HTTP server with string routing...") + +-- Start HTTP server +http.listen(3000) + +-- Home page +http.route("GET", "/", function(req) + local visits = session.get("visits") or 0 + visits = visits + 1 + session.set("visits", visits) + + return http.html([[ +

Welcome to Moonshark!

+

You've visited this page ]] .. visits .. [[ times.

+

Login | User Profile

+

API Test | File Access

+ ]]) +end) + +-- User profile with dynamic parameter +http.route("GET", "/users/:id", function(req) + local userId = req.params.id + return http.html([[ +

User Profile

+

User ID: ]] .. userId .. [[

+

View Posts

+

Home

+ ]]) +end) + +-- User posts with multiple parameters +http.route("GET", "/users/:id/posts", function(req) + local userId = req.params.id + return http.json({ + user_id = userId, + posts = { + {id = 1, title = "First Post", content = "Hello world!"}, + {id = 2, title = "Second Post", content = "Learning Lua routing!"} + } + }) +end) + +-- Blog post with slug parameter +http.route("GET", "/blog/:slug", function(req) + local slug = req.params.slug + return http.html([[ +

Blog Post: ]] .. slug .. [[

+

This is the content for blog post "]] .. slug .. [["

+

View Comments

+

Home

+ ]]) +end) + +-- Blog comments +http.route("GET", "/blog/:slug/comments", function(req) + local slug = req.params.slug + return http.json({ + blog_slug = slug, + comments = { + {author = "Alice", comment = "Great post!"}, + {author = "Bob", comment = "Very informative."} + } + }) +end) + +-- Wildcard route for file serving +http.route("GET", "/files/*path", function(req) + local filePath = req.params.path + return http.html([[ +

File Access

+

Requested file: ]] .. filePath .. [[

+

In a real application, this would serve the file content.

+

Home

+ ]]) +end) + +-- API endpoints with parameters +http.route("GET", "/api/users/:id", function(req) + local userId = req.params.id + return http.json({ + id = tonumber(userId), + name = "User " .. userId, + email = "user" .. userId .. "@example.com", + active = true + }) +end) + +http.route("PUT", "/api/users/:id", function(req) + local userId = req.params.id + local userData = req.form + + return http.json({ + success = true, + message = "User " .. userId .. " updated", + data = userData + }) +end) + +http.route("DELETE", "/api/users/:id", function(req) + local userId = req.params.id + return http.json({ + success = true, + message = "User " .. userId .. " deleted" + }) +end) + +-- Login form with CSRF protection +http.route("GET", "/login", function(req) + return http.html([[ +

Login

+
+ ]] .. csrf.field() .. [[ +
+
+ +
+

Home

+ ]]) +end) + +-- Handle login POST +http.route("POST", "/login", function(req) + if not csrf.validate() then + http.status(403) + return "CSRF token invalid" + end + + local username = req.form.username + local password = req.form.password + + if username == "admin" and password == "secret" then + session.set("user", username) + session.flash("success", "Login successful!") + return http.redirect("/dashboard") + else + session.flash("error", "Invalid credentials") + return http.redirect("/login") + end +end) + +-- Dashboard (requires login) +http.route("GET", "/dashboard", function(req) + local user = session.get("user") + if not user then + return http.redirect("/login") + end + + local success = session.get_flash("success") + local error = session.get_flash("error") + + return http.html([[ +

Dashboard

+ ]] .. (success and ("

" .. success .. "

") or "") .. [[ + ]] .. (error and ("

" .. error .. "

") or "") .. [[ +

Welcome, ]] .. user .. [[!

+

Logout

+

Home

+ ]]) +end) + +-- Logout +http.route("GET", "/logout", function(req) + session.set("user", nil) + session.flash("info", "You have been logged out") + return http.redirect("/") +end) + +-- Catch-all route for 404s (must be last) +http.route("GET", "*path", function(req) + http.status(404) + return http.html([[ +

404 - Page Not Found

+

The requested path "]] .. req.params.path .. [[" was not found.

+

Go Home

+ ]]) +end) + +print("Server configured with string routing. Listening on http://localhost:3000") +print("Try these routes:") +print(" GET /") +print(" GET /users/123") +print(" GET /users/456/posts") +print(" GET /blog/my-first-post") +print(" GET /blog/lua-tutorial/comments") +print(" GET /files/docs/readme.txt") +print(" GET /api/users/789") +print(" GET /nonexistent (404 handler)") \ No newline at end of file diff --git a/utils/debug.go b/utils/debug.go deleted file mode 100644 index aeed809..0000000 --- a/utils/debug.go +++ /dev/null @@ -1,217 +0,0 @@ -package utils - -import ( - "fmt" - "html/template" - "runtime" - "strings" - "time" - - "Moonshark/config" - "Moonshark/metadata" -) - -// ComponentStats holds stats from various system components -type ComponentStats struct { - RouteCount int // Number of routes - BytecodeBytes int64 // Total size of bytecode in bytes - ModuleCount int // Number of loaded modules - SessionStats map[string]uint64 // Session cache statistics -} - -// SystemStats represents system statistics for debugging -type SystemStats struct { - Timestamp time.Time - GoVersion string - GoRoutines int - Memory runtime.MemStats - Components ComponentStats - Version string - Config *config.Config -} - -// CollectSystemStats gathers basic system statistics -func CollectSystemStats(cfg *config.Config) SystemStats { - var stats SystemStats - var mem runtime.MemStats - - stats.Timestamp = time.Now() - stats.GoVersion = runtime.Version() - stats.GoRoutines = runtime.NumGoroutine() - stats.Version = metadata.Version - stats.Config = cfg - - runtime.ReadMemStats(&mem) - stats.Memory = mem - - return stats -} - -// DebugStatsPage generates an HTML debug stats page -func DebugStatsPage(stats SystemStats) string { - const debugTemplate = ` - - - - Moonshark - - - -

Moonshark

-
Generated at: {{.Timestamp.Format "2006-01-02 15:04:05"}}
- -
-

Server

-
- - -
Version{{.Version}}
-
-
- -
-

System

-
- - - -
Go Version{{.GoVersion}}
Goroutines{{.GoRoutines}}
-
-
- -
-

Memory

-
- - - - - -
Allocated{{ByteCount .Memory.Alloc}}
Total Allocated{{ByteCount .Memory.TotalAlloc}}
System Memory{{ByteCount .Memory.Sys}}
GC Cycles{{.Memory.NumGC}}
-
-
- -
-

Sessions

-
- - - - - - - -
Active Sessions{{index .Components.SessionStats "entries"}}
Cache Size{{ByteCount (index .Components.SessionStats "bytes")}}
Max Cache Size{{ByteCount (index .Components.SessionStats "max_bytes")}}
Cache Gets{{index .Components.SessionStats "gets"}}
Cache Sets{{index .Components.SessionStats "sets"}}
Cache Misses{{index .Components.SessionStats "misses"}}
-
-
- -
-

LuaRunner

-
- - - - - - -
InterpreterLuaJIT 2.1 (Lua 5.1)
Active Routes{{.Components.RouteCount}}
Bytecode Size{{ByteCount .Components.BytecodeBytes}}
Loaded Modules{{.Components.ModuleCount}}
State Pool Size{{.Config.Runner.PoolSize}}
-
-
- -
-

Config

-
- - - - - - - - - - -
Port{{.Config.Server.Port}}
Pool Size{{.Config.Runner.PoolSize}}
Debug Mode{{.Config.Server.Debug}}
Log Level{{.Config.Server.LogLevel}}
HTTP Logging{{.Config.Server.HTTPLogging}}
Directories -
Routes: {{.Config.Dirs.Routes}}
-
Static: {{.Config.Dirs.Static}}
-
Override: {{.Config.Dirs.Override}}
-
Libs: {{range .Config.Dirs.Libs}}{{.}}, {{end}}
-
-
-
- - -` - - // Create a template function map - funcMap := template.FuncMap{ - "ByteCount": func(b any) string { - var bytes uint64 - - switch v := b.(type) { - case uint64: - bytes = v - case int64: - bytes = uint64(v) - case int: - bytes = uint64(v) - default: - return fmt.Sprintf("%T: %v", b, b) - } - - const unit = 1024 - if bytes < unit { - return fmt.Sprintf("%d B", bytes) - } - div, exp := uint64(unit), 0 - for n := bytes / unit; n >= unit; n /= unit { - div *= unit - exp++ - } - return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) - }, - } - - // Parse the template - tmpl, err := template.New("debug").Funcs(funcMap).Parse(debugTemplate) - if err != nil { - return fmt.Sprintf("Error parsing template: %v", err) - } - - // Execute the template - var output strings.Builder - if err := tmpl.Execute(&output, stats); err != nil { - return fmt.Sprintf("Error executing template: %v", err) - } - - return output.String() -} diff --git a/utils/errorPages.go b/utils/errorPages.go deleted file mode 100644 index ba8437d..0000000 --- a/utils/errorPages.go +++ /dev/null @@ -1,231 +0,0 @@ -package utils - -import ( - "math/rand" -) - -var dbg bool - -// ErrorType represents HTTP error types -type ErrorType int - -const ( - ErrorTypeNotFound ErrorType = 404 - ErrorTypeMethodNotAllowed ErrorType = 405 - ErrorTypeInternalError ErrorType = 500 - ErrorTypeForbidden ErrorType = 403 // Added CSRF/Forbidden error type -) - -func Debug(enabled bool) { dbg = enabled } - -// ErrorPage generates an HTML error page based on the error type -// It first checks for an override file, and if not found, generates a default page -func ErrorPage(errorType ErrorType, url string, errMsg string) string { - // No override found, generate default page - switch errorType { - case ErrorTypeNotFound: - return generateNotFoundHTML(url) - case ErrorTypeMethodNotAllowed: - return generateMethodNotAllowedHTML(url) - case ErrorTypeInternalError: - return generateInternalErrorHTML(dbg, url, errMsg) - case ErrorTypeForbidden: - return generateForbiddenHTML(dbg, url, errMsg) - default: - // Fallback to internal error - return generateInternalErrorHTML(dbg, url, errMsg) - } -} - -// NotFoundPage generates a 404 Not Found error page -func NotFoundPage(url string) string { - return ErrorPage(ErrorTypeNotFound, url, "") -} - -// MethodNotAllowedPage generates a 405 Method Not Allowed error page -func MethodNotAllowedPage(url string) string { - return ErrorPage(ErrorTypeMethodNotAllowed, url, "") -} - -// InternalErrorPage generates a 500 Internal Server Error page -func InternalErrorPage(url string, errMsg string) string { - return ErrorPage(ErrorTypeInternalError, url, errMsg) -} - -// ForbiddenPage generates a 403 Forbidden error page -func ForbiddenPage(url string, errMsg string) string { - return ErrorPage(ErrorTypeForbidden, url, errMsg) -} - -// generateInternalErrorHTML creates a 500 Internal Server Error page -func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string { - errorMessages := []string{ - "Oops! Something went wrong", - "Oh no! The server choked", - "Well, this is embarrassing...", - "Houston, we have a problem", - "Gremlins in the system", - "The server is taking a coffee break", - "Moonshark encountered a lunar eclipse", - "Our code monkeys are working on it", - "The server is feeling under the weather", - "500 Brain Not Found", - } - - randomMessage := errorMessages[rand.Intn(len(errorMessages))] - return generateErrorHTML("500", randomMessage, "Internal Server Error", debugMode, errMsg) -} - -// generateForbiddenHTML creates a 403 Forbidden error page -func generateForbiddenHTML(debugMode bool, url string, errMsg string) string { - errorMessages := []string{ - "Access denied", - "You shall not pass", - "This area is off-limits", - "Security check failed", - "Invalid security token", - "Request blocked for security reasons", - "Permission denied", - "Security violation detected", - "This request was rejected", - "Security first, access second", - } - - defaultMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt." - if errMsg == "" { - errMsg = defaultMsg - } - - randomMessage := errorMessages[rand.Intn(len(errorMessages))] - return generateErrorHTML("403", randomMessage, "Forbidden", debugMode, errMsg) -} - -// generateNotFoundHTML creates a 404 Not Found error page -func generateNotFoundHTML(url string) string { - errorMessages := []string{ - "Nothing to see here", - "This page is on vacation", - "The page is missing in action", - "This page has left the building", - "This page is in another castle", - "Sorry, we can't find that", - "The page you're looking for doesn't exist", - "Lost in space", - "That's a 404", - "Page not found", - } - - randomMessage := errorMessages[rand.Intn(len(errorMessages))] - return generateErrorHTML("404", randomMessage, "Page Not Found", false, url) -} - -// generateMethodNotAllowedHTML creates a 405 Method Not Allowed error page -func generateMethodNotAllowedHTML(url string) string { - errorMessages := []string{ - "That's not how this works", - "Method not allowed", - "Wrong way!", - "This method is not supported", - "You can't do that here", - "Sorry, wrong door", - "That method won't work here", - "Try a different approach", - "Access denied for this method", - "Method mismatch", - } - - randomMessage := errorMessages[rand.Intn(len(errorMessages))] - return generateErrorHTML("405", randomMessage, "Method Not Allowed", false, url) -} - -// generateErrorHTML creates the common HTML structure for error pages -func generateErrorHTML(errorCode, mainMessage, subMessage string, showDebugInfo bool, codeContent string) string { - errorHTML := ` - - - ` + errorCode + ` - - - -
-

` + errorCode + `

-

` + mainMessage + `

-
` + subMessage + `
` - - if codeContent != "" { - errorHTML += ` - ` + codeContent + `` - } - - // Add a note for debug mode - if showDebugInfo { - errorHTML += ` -

- An error occurred while processing your request.
- Please check the server logs for details. -

` - } - - errorHTML += ` -
- -` - - return errorHTML -} diff --git a/utils/formData.go b/utils/formData.go deleted file mode 100644 index e29e1d4..0000000 --- a/utils/formData.go +++ /dev/null @@ -1,115 +0,0 @@ -package utils - -import ( - "mime/multipart" - "strings" - - "github.com/goccy/go-json" - - "github.com/valyala/fasthttp" -) - -func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) { - contentType := string(ctx.Request.Header.ContentType()) - formData := make(map[string]any) - - switch { - case strings.Contains(contentType, "multipart/form-data"): - if err := parseMultipartInto(ctx, formData); err != nil { - return nil, err - } - - case strings.Contains(contentType, "application/x-www-form-urlencoded"): - args := ctx.PostArgs() - args.VisitAll(func(key, value []byte) { - appendValue(formData, string(key), string(value)) - }) - - case strings.Contains(contentType, "application/json"): - if err := json.Unmarshal(ctx.PostBody(), &formData); err != nil { - return nil, err - } - - default: - // Leave formData empty if content-type is unrecognized - } - - return formData, nil -} - -func parseMultipartInto(ctx *fasthttp.RequestCtx, formData map[string]any) error { - form, err := ctx.MultipartForm() - if err != nil { - return err - } - - for key, values := range form.Value { - if len(values) == 1 { - formData[key] = values[0] - } else if len(values) > 1 { - formData[key] = values - } - } - - if len(form.File) > 0 { - files := make(map[string]any, len(form.File)) - for fieldName, fileHeaders := range form.File { - if len(fileHeaders) == 1 { - files[fieldName] = fileInfoToMap(fileHeaders[0]) - } else { - fileInfos := make([]map[string]any, len(fileHeaders)) - for i, fh := range fileHeaders { - fileInfos[i] = fileInfoToMap(fh) - } - files[fieldName] = fileInfos - } - } - formData["_files"] = files - } - - return nil -} - -func appendValue(formData map[string]any, key, value string) { - if existing, exists := formData[key]; exists { - switch v := existing.(type) { - case string: - formData[key] = []string{v, value} - case []string: - formData[key] = append(v, value) - } - } else { - formData[key] = value - } -} - -func fileInfoToMap(fh *multipart.FileHeader) map[string]any { - ct := fh.Header.Get("Content-Type") - if ct == "" { - ct = getMimeType(fh.Filename) - } - - return map[string]any{ - "filename": fh.Filename, - "size": fh.Size, - "mimetype": ct, - } -} - -func getMimeType(filename string) string { - if i := strings.LastIndex(filename, "."); i >= 0 { - switch filename[i:] { - case ".pdf": - return "application/pdf" - case ".png": - return "image/png" - case ".jpg", ".jpeg": - return "image/jpeg" - case ".gif": - return "image/gif" - case ".svg": - return "image/svg+xml" - } - } - return "application/octet-stream" -} diff --git a/watchers/manager.go b/watchers/manager.go deleted file mode 100644 index 93aa9c4..0000000 --- a/watchers/manager.go +++ /dev/null @@ -1,92 +0,0 @@ -package watchers - -import ( - "errors" - "sync" - - "Moonshark/logger" -) - -var ( - ErrDirectoryNotFound = errors.New("directory not found") - ErrAlreadyWatching = errors.New("already watching directory") -) - -// WatcherManager now just manages watcher lifecycle - no polling logic -type WatcherManager struct { - watchers map[string]*Watcher - mu sync.RWMutex -} - -func NewWatcherManager() *WatcherManager { - return &WatcherManager{ - watchers: make(map[string]*Watcher), - } -} - -func (m *WatcherManager) Close() error { - m.mu.Lock() - defer m.mu.Unlock() - - for _, watcher := range m.watchers { - watcher.Close() - } - - m.watchers = make(map[string]*Watcher) - return nil -} - -func (m *WatcherManager) WatchDirectory(config WatcherConfig) error { - m.mu.Lock() - defer m.mu.Unlock() - - if _, exists := m.watchers[config.Dir]; exists { - return ErrAlreadyWatching - } - - watcher, err := NewWatcher(config) - if err != nil { - return err - } - - m.watchers[config.Dir] = watcher - logger.Debugf("WatcherManager added watcher for %s", config.Dir) - - return nil -} - -func (m *WatcherManager) UnwatchDirectory(dir string) error { - m.mu.Lock() - defer m.mu.Unlock() - - watcher, exists := m.watchers[dir] - if !exists { - return ErrDirectoryNotFound - } - - watcher.Close() - delete(m.watchers, dir) - logger.Debugf("WatcherManager removed watcher for %s", dir) - - return nil -} - -func (m *WatcherManager) GetWatcher(dir string) (*Watcher, bool) { - m.mu.RLock() - defer m.mu.RUnlock() - - watcher, exists := m.watchers[dir] - return watcher, exists -} - -func (m *WatcherManager) ListWatching() []string { - m.mu.RLock() - defer m.mu.RUnlock() - - dirs := make([]string, 0, len(m.watchers)) - for dir := range m.watchers { - dirs = append(dirs, dir) - } - - return dirs -} diff --git a/watchers/watcher.go b/watchers/watcher.go deleted file mode 100644 index a6ff037..0000000 --- a/watchers/watcher.go +++ /dev/null @@ -1,236 +0,0 @@ -package watchers - -import ( - "fmt" - "os" - "path/filepath" - "sync" - "time" - - "Moonshark/logger" -) - -const ( - defaultDebounceTime = 300 * time.Millisecond - defaultPollInterval = 1 * time.Second -) - -type FileChange struct { - Path string - IsNew bool - IsDeleted bool -} - -type FileInfo struct { - ModTime time.Time -} - -// Watcher is now self-contained and manages its own polling -type Watcher struct { - dir string - files map[string]FileInfo - filesMu sync.RWMutex - callback func([]FileChange) error - debounceTime time.Duration - pollInterval time.Duration - recursive bool - - // Self-management - done chan struct{} - debounceTimer *time.Timer - debouncing bool - debounceMu sync.Mutex - wg sync.WaitGroup -} - -type WatcherConfig struct { - Dir string - Callback func([]FileChange) error - DebounceTime time.Duration - PollInterval time.Duration - Recursive bool -} - -func NewWatcher(config WatcherConfig) (*Watcher, error) { - if config.DebounceTime == 0 { - config.DebounceTime = defaultDebounceTime - } - if config.PollInterval == 0 { - config.PollInterval = defaultPollInterval - } - - w := &Watcher{ - dir: config.Dir, - files: make(map[string]FileInfo), - callback: config.Callback, - debounceTime: config.DebounceTime, - pollInterval: config.PollInterval, - recursive: config.Recursive, - done: make(chan struct{}), - } - - if err := w.scanDirectory(); err != nil { - return nil, err - } - - w.wg.Add(1) - go w.watchLoop() - - return w, nil -} - -func (w *Watcher) Close() { - close(w.done) - w.wg.Wait() - - w.debounceMu.Lock() - if w.debounceTimer != nil { - w.debounceTimer.Stop() - } - w.debounceMu.Unlock() -} - -func (w *Watcher) GetDir() string { - return w.dir -} - -// watchLoop is the main polling loop for this watcher -func (w *Watcher) watchLoop() { - defer w.wg.Done() - ticker := time.NewTicker(w.pollInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - w.checkAndNotify() - case <-w.done: - return - } - } -} - -// checkAndNotify combines change detection and notification -func (w *Watcher) checkAndNotify() { - changed, changedFiles := w.detectChanges() - if changed { - w.notifyChange(changedFiles) - } -} - -// detectChanges scans directory and returns changes -func (w *Watcher) detectChanges() (bool, []FileChange) { - newFiles := make(map[string]FileInfo) - var changedFiles []FileChange - changed := false - - err := filepath.Walk(w.dir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return nil - } - - if !w.recursive && info.IsDir() && path != w.dir { - return filepath.SkipDir - } - - currentInfo := FileInfo{ModTime: info.ModTime()} - newFiles[path] = currentInfo - - w.filesMu.RLock() - prevInfo, exists := w.files[path] - w.filesMu.RUnlock() - - if !exists { - changed = true - changedFiles = append(changedFiles, FileChange{Path: path, IsNew: true}) - w.logDebug("File added: %s", path) - } else if currentInfo.ModTime != prevInfo.ModTime { - changed = true - changedFiles = append(changedFiles, FileChange{Path: path}) - w.logDebug("File changed: %s", path) - } - - return nil - }) - - if err != nil { - w.logError("Error scanning directory: %v", err) - return false, nil - } - - // Check for deletions - w.filesMu.RLock() - for path := range w.files { - if _, exists := newFiles[path]; !exists { - changed = true - changedFiles = append(changedFiles, FileChange{Path: path, IsDeleted: true}) - w.logDebug("File deleted: %s", path) - } - } - w.filesMu.RUnlock() - - if changed { - w.filesMu.Lock() - w.files = newFiles - w.filesMu.Unlock() - } - - return changed, changedFiles -} - -func (w *Watcher) scanDirectory() error { - w.filesMu.Lock() - defer w.filesMu.Unlock() - - w.files = make(map[string]FileInfo) - - return filepath.Walk(w.dir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return nil - } - - if !w.recursive && info.IsDir() && path != w.dir { - return filepath.SkipDir - } - - w.files[path] = FileInfo{ModTime: info.ModTime()} - return nil - }) -} - -func (w *Watcher) notifyChange(changedFiles []FileChange) { - w.debounceMu.Lock() - defer w.debounceMu.Unlock() - - if w.debouncing && w.debounceTimer != nil { - w.debounceTimer.Stop() - } - w.debouncing = true - - // Copy to avoid race conditions - filesCopy := make([]FileChange, len(changedFiles)) - copy(filesCopy, changedFiles) - - w.debounceTimer = time.AfterFunc(w.debounceTime, func() { - var err error - if w.callback != nil { - err = w.callback(filesCopy) - } - - if err != nil { - w.logError("Callback error: %v", err) - } - - w.debounceMu.Lock() - w.debouncing = false - w.debounceMu.Unlock() - }) -} - -func (w *Watcher) logDebug(format string, args ...any) { - logger.Debugf("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...)) -} - -func (w *Watcher) logError(format string, args ...any) { - logger.Errorf("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...)) -}