first attempt
This commit is contained in:
parent
c53c54a5d9
commit
97e3ec6547
@ -1,21 +0,0 @@
|
||||
server {
|
||||
port 3117
|
||||
debug false
|
||||
http_logging false
|
||||
static_prefix "public"
|
||||
}
|
||||
|
||||
runner {
|
||||
pool_size 0 -- 0 defaults to GOMAXPROCS
|
||||
}
|
||||
|
||||
dirs = {
|
||||
routes "routes"
|
||||
static "public"
|
||||
fs "fs"
|
||||
data "data"
|
||||
override "override"
|
||||
libs {
|
||||
"libs"
|
||||
}
|
||||
}
|
@ -1,55 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
fin "git.sharkk.net/Sharkk/Fin"
|
||||
)
|
||||
|
||||
// Config represents the current loaded configuration for the server
|
||||
type Config struct {
|
||||
Server struct {
|
||||
Port int
|
||||
Debug bool
|
||||
HTTPLogging bool
|
||||
PublicPrefix string
|
||||
}
|
||||
|
||||
Runner struct {
|
||||
PoolSize int
|
||||
}
|
||||
|
||||
Dirs struct {
|
||||
Routes string
|
||||
Public string
|
||||
FS string
|
||||
Data string
|
||||
Override string
|
||||
Libs []string
|
||||
}
|
||||
|
||||
data *fin.Data
|
||||
}
|
||||
|
||||
// NewConfig creates a new configuration with default values
|
||||
func New(data *fin.Data) *Config {
|
||||
config := &Config{
|
||||
data: data,
|
||||
}
|
||||
|
||||
config.Server.Port = data.GetOr("server.port", 3117).(int)
|
||||
config.Server.Debug = data.GetOr("server.debug", false).(bool)
|
||||
config.Server.HTTPLogging = data.GetOr("server.http_logging", true).(bool)
|
||||
config.Server.PublicPrefix = data.GetOr("server.public_prefix", "public").(string)
|
||||
|
||||
config.Runner.PoolSize = data.GetOr("runner.pool_size", runtime.GOMAXPROCS(0)).(int)
|
||||
|
||||
config.Dirs.Routes = data.GetOr("dirs.routes", "routes").(string)
|
||||
config.Dirs.Public = data.GetOr("dirs.public", "public").(string)
|
||||
config.Dirs.FS = data.GetOr("dirs.fs", "fs").(string)
|
||||
config.Dirs.Data = data.GetOr("dirs.data", "data").(string)
|
||||
config.Dirs.Override = data.GetOr("dirs.override", "override").(string)
|
||||
config.Dirs.Libs = data.GetOr("dirs.libs", []string{"libs"}).([]string)
|
||||
|
||||
return config
|
||||
}
|
35
go.mod
35
go.mod
@ -3,35 +3,20 @@ module Moonshark
|
||||
go 1.24.1
|
||||
|
||||
require (
|
||||
git.sharkk.net/Go/Color v1.1.0
|
||||
git.sharkk.net/Go/LRU v1.0.0
|
||||
git.sharkk.net/Sharkk/Fin v1.3.0
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.3
|
||||
github.com/VictoriaMetrics/fastcache v1.12.4
|
||||
github.com/alexedwards/argon2id v1.0.0
|
||||
github.com/deneonet/benc v1.1.8
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.4
|
||||
github.com/goccy/go-json v0.10.5
|
||||
github.com/golang/snappy v1.0.0
|
||||
github.com/matoous/go-nanoid/v2 v2.1.0
|
||||
github.com/valyala/bytebufferpool v1.0.0
|
||||
github.com/valyala/fasthttp v1.62.0
|
||||
zombiezen.com/go/sqlite v1.4.2
|
||||
github.com/valyala/fasthttp v1.63.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.1.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/VictoriaMetrics/fastcache v1.12.5 // indirect
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/deneonet/benc v1.1.8 // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/crypto v0.38.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250531010427-b6e5de432a8b // indirect
|
||||
github.com/matoous/go-nanoid/v2 v2.1.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
modernc.org/libc v1.65.8 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.37.1 // indirect
|
||||
)
|
||||
|
129
go.sum
129
go.sum
@ -1,133 +1,32 @@
|
||||
git.sharkk.net/Go/Color v1.1.0 h1:1eyUwlcerJLo9/42GSnQxOY84/Htdwz/QTu3FGgLEXk=
|
||||
git.sharkk.net/Go/Color v1.1.0/go.mod h1:cyZFLbUh+GkpsIABVxb5w9EZM+FPj+q9GkCsoECaeTI=
|
||||
git.sharkk.net/Go/LRU v1.0.0 h1:/KqdRVhHldi23aVfQZ4ss6vhCWZqA3vFiQyf1MJPpQc=
|
||||
git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50=
|
||||
git.sharkk.net/Sharkk/Fin v1.3.0 h1:6/f7+h382jJOeo09cgdzH+PGb5XdvajZZRiES52sBkI=
|
||||
git.sharkk.net/Sharkk/Fin v1.3.0/go.mod h1:ca0Ej9yCM/vHh1o3YMvBZspme3EtbwoEL2UXN5UPXMo=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.3 h1:SuLz4X/k+sMy+Uj1lhEy6brJtvtzHLdivUcu5K91y+o=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.3/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
|
||||
github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0=
|
||||
github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI=
|
||||
github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w=
|
||||
github.com/alexedwards/argon2id v1.0.0/go.mod h1:tYKkqIjzXvZdzPvADMWOEZ+l6+BD6CtBXMj5fnJppiw=
|
||||
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8=
|
||||
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.4 h1:j83C2pzDBaVP4FPpyBzP4Dch61plzIko/t7zFvjOK2I=
|
||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.4/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
|
||||
github.com/VictoriaMetrics/fastcache v1.12.5 h1:966OX9JjqYmDAFdp3wEXLwzukiHIm+GVlZHv6B8KW3k=
|
||||
github.com/VictoriaMetrics/fastcache v1.12.5/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI=
|
||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/deneonet/benc v1.1.8 h1:Qk9diyH0UcnduvCrZ62mBrwUeSZzte4kQxMbclVdhW4=
|
||||
github.com/deneonet/benc v1.1.8/go.mod h1:UCfkM5Od0B2huwv/ZItvtUb7QnALFt9YXtX8NXX4Lts=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE=
|
||||
github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0=
|
||||
github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg=
|
||||
github.com/valyala/fasthttp v1.63.0 h1:DisIL8OjB7ul2d7cBaMRcKTQDYnrGy56R4FCiuDP0Ns=
|
||||
github.com/valyala/fasthttp v1.63.0/go.mod h1:REc4IeW+cAEyLrRPa5A81MIjvz0QE1laoTX2EaPHKJM=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
||||
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
||||
golang.org/x/exp v0.0.0-20250531010427-b6e5de432a8b h1:QoALfVG9rhQ/M7vYDScfPdWjGL9dlsVVM5VGh7aKoAA=
|
||||
golang.org/x/exp v0.0.0-20250531010427-b6e5de432a8b/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
|
||||
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
|
||||
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.26.1 h1:+X5NtzVBn0KgsBCBe+xkDC7twLb/jNVj9FPgiwSQO3s=
|
||||
modernc.org/cc/v4 v4.26.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
||||
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
||||
modernc.org/fileutil v1.3.1 h1:8vq5fe7jdtEvoCf3Zf9Nm0Q05sH6kGx0Op2CPx1wTC8=
|
||||
modernc.org/fileutil v1.3.1/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/libc v1.65.8 h1:7PXRJai0TXZ8uNA3srsmYzmTyrLoHImV5QxHeni108Q=
|
||||
modernc.org/libc v1.65.8/go.mod h1:011EQibzzio/VX3ygj1qGFt5kMjP0lHb0qCW5/D/pQU=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.37.1 h1:EgHJK/FPoqC+q2YBXg7fUmES37pCHFc97sI7zSayBEs=
|
||||
modernc.org/sqlite v1.37.1/go.mod h1:XwdRtsE1MpiBcL54+MbKcaDvcuej+IYSMfLN6gSKV8g=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
zombiezen.com/go/sqlite v1.4.2 h1:KZXLrBuJ7tKNEm+VJcApLMeQbhmAUOKA5VWS93DfFRo=
|
||||
zombiezen.com/go/sqlite v1.4.2/go.mod h1:5Kd4taTAD4MkBzT25mQ9uaAlLjyR0rFhsR6iINO70jc=
|
||||
|
756
http/http.go
Normal file
756
http/http.go
Normal file
@ -0,0 +1,756 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
_ "embed"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"Moonshark/http/router"
|
||||
"Moonshark/http/sessions"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
//go:embed http.lua
|
||||
var httpLuaCode string
|
||||
|
||||
type Server struct {
|
||||
server *fasthttp.Server
|
||||
router *router.Router
|
||||
sessions *sessions.SessionManager
|
||||
state *luajit.State
|
||||
stateMu sync.Mutex
|
||||
funcCounter int
|
||||
}
|
||||
|
||||
type RequestContext struct {
|
||||
Method string
|
||||
Path string
|
||||
Headers map[string]string
|
||||
Query map[string]string
|
||||
Form map[string]any
|
||||
Cookies map[string]string
|
||||
Session *sessions.Session
|
||||
Body string
|
||||
Params map[string]string
|
||||
}
|
||||
|
||||
var globalServer *Server
|
||||
|
||||
func NewServer(state *luajit.State) *Server {
|
||||
return &Server{
|
||||
router: router.New(),
|
||||
sessions: sessions.NewSessionManager(10000),
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterHTTPFunctions(L *luajit.State) error {
|
||||
globalServer = NewServer(L)
|
||||
|
||||
functions := map[string]luajit.GoFunction{
|
||||
"__http_listen": globalServer.httpListen,
|
||||
"__http_route": globalServer.httpRoute,
|
||||
"__http_set_status": httpSetStatus,
|
||||
"__http_set_header": httpSetHeader,
|
||||
"__http_redirect": httpRedirect,
|
||||
"__session_get": globalServer.sessionGet,
|
||||
"__session_set": globalServer.sessionSet,
|
||||
"__session_flash": globalServer.sessionFlash,
|
||||
"__session_get_flash": globalServer.sessionGetFlash,
|
||||
"__cookie_set": cookieSet,
|
||||
"__cookie_get": cookieGet,
|
||||
"__csrf_generate": globalServer.csrfGenerate,
|
||||
"__csrf_validate": globalServer.csrfValidate,
|
||||
}
|
||||
|
||||
for name, fn := range functions {
|
||||
if err := L.RegisterGoFunction(name, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return L.DoString(httpLuaCode)
|
||||
}
|
||||
|
||||
func (s *Server) httpListen(state *luajit.State) int {
|
||||
port, err := state.SafeToNumber(1)
|
||||
if err != nil {
|
||||
return state.PushError("listen: port must be number")
|
||||
}
|
||||
|
||||
s.server = &fasthttp.Server{
|
||||
Handler: s.requestHandler,
|
||||
Name: "Moonshark/1.0",
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%d", int(port))
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(addr); err != nil {
|
||||
fmt.Printf("Server error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
fmt.Printf("Server listening on port %d\n", int(port))
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) httpRoute(state *luajit.State) int {
|
||||
method, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("route: method must be string")
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
return state.PushError("route: path must be string")
|
||||
}
|
||||
|
||||
if !state.IsFunction(3) {
|
||||
return state.PushError("route: handler must be function")
|
||||
}
|
||||
|
||||
// Store function and get reference
|
||||
state.PushCopy(3)
|
||||
funcRef := s.storeFunction()
|
||||
|
||||
// Add route to router
|
||||
if err := s.router.AddRoute(strings.ToUpper(method), path, funcRef); err != nil {
|
||||
return state.PushError("route: failed to add route: %s", err.Error())
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) storeFunction() int {
|
||||
s.state.GetGlobal("__moonshark_functions")
|
||||
if s.state.IsNil(-1) {
|
||||
s.state.Pop(1)
|
||||
s.state.NewTable()
|
||||
s.state.PushCopy(-1)
|
||||
s.state.SetGlobal("__moonshark_functions")
|
||||
}
|
||||
|
||||
s.funcCounter++
|
||||
s.state.PushNumber(float64(s.funcCounter))
|
||||
s.state.PushCopy(-3)
|
||||
s.state.SetTable(-3)
|
||||
s.state.Pop(2)
|
||||
|
||||
return s.funcCounter
|
||||
}
|
||||
|
||||
func (s *Server) getFunction(ref int) bool {
|
||||
s.state.GetGlobal("__moonshark_functions")
|
||||
if s.state.IsNil(-1) {
|
||||
s.state.Pop(1)
|
||||
return false
|
||||
}
|
||||
|
||||
s.state.PushNumber(float64(ref))
|
||||
s.state.GetTable(-2)
|
||||
isFunc := s.state.IsFunction(-1)
|
||||
if !isFunc {
|
||||
s.state.Pop(2)
|
||||
return false
|
||||
}
|
||||
|
||||
s.state.Remove(-2)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) requestHandler(ctx *fasthttp.RequestCtx) {
|
||||
method := string(ctx.Method())
|
||||
path := string(ctx.Path())
|
||||
|
||||
// Look up route in router
|
||||
handlerRef, params, found := s.router.Lookup(method, path)
|
||||
if !found {
|
||||
ctx.SetStatusCode(404)
|
||||
ctx.SetBodyString("Not Found")
|
||||
return
|
||||
}
|
||||
|
||||
reqCtx := s.buildRequestContext(ctx, params)
|
||||
reqCtx.Session.AdvanceFlash()
|
||||
|
||||
s.stateMu.Lock()
|
||||
defer s.stateMu.Unlock()
|
||||
|
||||
s.setupRequestEnvironment(reqCtx)
|
||||
|
||||
if !s.getFunction(handlerRef) {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString("Handler not found")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.state.PushValue(s.requestToTable(reqCtx)); err != nil {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString("Failed to create request object")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.state.Call(1, 1); err != nil {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString(fmt.Sprintf("Handler error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
var responseBody string
|
||||
if s.state.GetTop() > 0 && !s.state.IsNil(-1) {
|
||||
responseBody = s.state.ToString(-1)
|
||||
s.state.Pop(1)
|
||||
}
|
||||
|
||||
s.updateSessionFromLua(reqCtx.Session)
|
||||
s.applyResponse(ctx, responseBody)
|
||||
s.sessions.ApplySessionCookie(ctx, reqCtx.Session)
|
||||
}
|
||||
|
||||
func (s *Server) setupRequestEnvironment(reqCtx *RequestContext) {
|
||||
s.state.PushValue(s.requestToTable(reqCtx))
|
||||
s.state.SetGlobal("__request")
|
||||
|
||||
s.state.PushValue(s.sessionToTable(reqCtx.Session))
|
||||
s.state.SetGlobal("__session")
|
||||
|
||||
s.state.NewTable()
|
||||
s.state.SetGlobal("__response")
|
||||
}
|
||||
|
||||
func (s *Server) requestToTable(reqCtx *RequestContext) map[string]any {
|
||||
return map[string]any{
|
||||
"method": reqCtx.Method,
|
||||
"path": reqCtx.Path,
|
||||
"headers": reqCtx.Headers,
|
||||
"query": reqCtx.Query,
|
||||
"form": reqCtx.Form,
|
||||
"cookies": reqCtx.Cookies,
|
||||
"body": reqCtx.Body,
|
||||
"params": reqCtx.Params,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) sessionToTable(session *sessions.Session) map[string]any {
|
||||
return map[string]any{
|
||||
"id": session.ID,
|
||||
"data": session.GetAll(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) updateSessionFromLua(session *sessions.Session) {
|
||||
s.state.GetGlobal("__session")
|
||||
if s.state.IsNil(-1) {
|
||||
s.state.Pop(1)
|
||||
return
|
||||
}
|
||||
|
||||
s.state.GetField(-1, "data")
|
||||
if s.state.IsTable(-1) {
|
||||
if data, err := s.state.ToTable(-1); err == nil {
|
||||
if dataMap, ok := data.(map[string]any); ok {
|
||||
session.Clear()
|
||||
for k, v := range dataMap {
|
||||
session.Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.state.Pop(2)
|
||||
}
|
||||
|
||||
func (s *Server) applyResponse(ctx *fasthttp.RequestCtx, body string) {
|
||||
s.state.GetGlobal("__response")
|
||||
if s.state.IsNil(-1) {
|
||||
s.state.Pop(1)
|
||||
if body != "" {
|
||||
ctx.SetBodyString(body)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.state.GetField(-1, "status")
|
||||
if s.state.IsNumber(-1) {
|
||||
ctx.SetStatusCode(int(s.state.ToNumber(-1)))
|
||||
}
|
||||
s.state.Pop(1)
|
||||
|
||||
s.state.GetField(-1, "headers")
|
||||
if s.state.IsTable(-1) {
|
||||
s.state.ForEachTableKV(-1, func(key, value string) bool {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
return true
|
||||
})
|
||||
}
|
||||
s.state.Pop(1)
|
||||
|
||||
s.state.GetField(-1, "cookies")
|
||||
if s.state.IsTable(-1) {
|
||||
s.applyCookies(ctx)
|
||||
}
|
||||
s.state.Pop(1)
|
||||
|
||||
s.state.Pop(1)
|
||||
|
||||
if body != "" {
|
||||
ctx.SetBodyString(body)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) applyCookies(ctx *fasthttp.RequestCtx) {
|
||||
s.state.ForEachArray(-1, func(i int, state *luajit.State) bool {
|
||||
if !state.IsTable(-1) {
|
||||
return true
|
||||
}
|
||||
|
||||
name := state.GetFieldString(-1, "name", "")
|
||||
value := state.GetFieldString(-1, "value", "")
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
defer fasthttp.ReleaseCookie(cookie)
|
||||
|
||||
cookie.SetKey(name)
|
||||
cookie.SetValue(value)
|
||||
cookie.SetPath(state.GetFieldString(-1, "path", "/"))
|
||||
|
||||
if domain := state.GetFieldString(-1, "domain", ""); domain != "" {
|
||||
cookie.SetDomain(domain)
|
||||
}
|
||||
|
||||
if state.GetFieldBool(-1, "secure", false) {
|
||||
cookie.SetSecure(true)
|
||||
}
|
||||
|
||||
if state.GetFieldBool(-1, "http_only", true) {
|
||||
cookie.SetHTTPOnly(true)
|
||||
}
|
||||
|
||||
if maxAge := state.GetFieldNumber(-1, "max_age", 0); maxAge > 0 {
|
||||
cookie.SetExpire(time.Now().Add(time.Duration(maxAge) * time.Second))
|
||||
}
|
||||
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) buildRequestContext(ctx *fasthttp.RequestCtx, params *router.Params) *RequestContext {
|
||||
reqCtx := &RequestContext{
|
||||
Method: string(ctx.Method()),
|
||||
Path: string(ctx.Path()),
|
||||
Headers: make(map[string]string),
|
||||
Query: make(map[string]string),
|
||||
Cookies: make(map[string]string),
|
||||
Body: string(ctx.PostBody()),
|
||||
Params: make(map[string]string),
|
||||
}
|
||||
|
||||
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
reqCtx.Headers[string(key)] = string(value)
|
||||
})
|
||||
|
||||
ctx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||
reqCtx.Cookies[string(key)] = string(value)
|
||||
})
|
||||
|
||||
ctx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
reqCtx.Query[string(key)] = string(value)
|
||||
})
|
||||
|
||||
// Convert router params to map
|
||||
for i, key := range params.Keys {
|
||||
if i < len(params.Values) {
|
||||
reqCtx.Params[key] = params.Values[i]
|
||||
}
|
||||
}
|
||||
|
||||
reqCtx.Form = s.parseForm(ctx)
|
||||
reqCtx.Session = s.sessions.GetSessionFromRequest(ctx)
|
||||
|
||||
return reqCtx
|
||||
}
|
||||
|
||||
func (s *Server) parseForm(ctx *fasthttp.RequestCtx) map[string]any {
|
||||
contentType := string(ctx.Request.Header.ContentType())
|
||||
form := make(map[string]any)
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(ctx.PostBody(), &data); err == nil {
|
||||
return data
|
||||
}
|
||||
} else if strings.Contains(contentType, "application/x-www-form-urlencoded") {
|
||||
ctx.PostArgs().VisitAll(func(key, value []byte) {
|
||||
form[string(key)] = string(value)
|
||||
})
|
||||
} else if strings.Contains(contentType, "multipart/form-data") {
|
||||
if multipartForm, err := ctx.MultipartForm(); err == nil {
|
||||
for key, values := range multipartForm.Value {
|
||||
if len(values) == 1 {
|
||||
form[key] = values[0]
|
||||
} else {
|
||||
form[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
if len(multipartForm.File) > 0 {
|
||||
files := make(map[string]any)
|
||||
for fieldName, fileHeaders := range multipartForm.File {
|
||||
if len(fileHeaders) == 1 {
|
||||
files[fieldName] = s.fileToMap(fileHeaders[0])
|
||||
} else {
|
||||
fileList := make([]map[string]any, len(fileHeaders))
|
||||
for i, fh := range fileHeaders {
|
||||
fileList[i] = s.fileToMap(fh)
|
||||
}
|
||||
files[fieldName] = fileList
|
||||
}
|
||||
}
|
||||
form["_files"] = files
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return form
|
||||
}
|
||||
|
||||
func (s *Server) fileToMap(fh *multipart.FileHeader) map[string]any {
|
||||
return map[string]any{
|
||||
"filename": fh.Filename,
|
||||
"size": fh.Size,
|
||||
"mimetype": fh.Header.Get("Content-Type"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) generateCSRFToken() string {
|
||||
bytes := make([]byte, 32)
|
||||
rand.Read(bytes)
|
||||
return base64.URLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// Lua function implementations
|
||||
func httpSetStatus(state *luajit.State) int {
|
||||
code, _ := state.SafeToNumber(1)
|
||||
state.GetGlobal("__response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__response")
|
||||
state.GetGlobal("__response")
|
||||
}
|
||||
state.PushNumber(code)
|
||||
state.SetField(-2, "status")
|
||||
state.Pop(1)
|
||||
return 0
|
||||
}
|
||||
|
||||
func httpSetHeader(state *luajit.State) int {
|
||||
name, _ := state.SafeToString(1)
|
||||
value, _ := state.SafeToString(2)
|
||||
|
||||
state.GetGlobal("__response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__response")
|
||||
state.GetGlobal("__response")
|
||||
}
|
||||
|
||||
state.GetField(-1, "headers")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "headers")
|
||||
}
|
||||
|
||||
state.PushString(value)
|
||||
state.SetField(-2, name)
|
||||
state.Pop(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
func httpRedirect(state *luajit.State) int {
|
||||
url, _ := state.SafeToString(1)
|
||||
status := 302.0
|
||||
if state.GetTop() >= 2 {
|
||||
status, _ = state.SafeToNumber(2)
|
||||
}
|
||||
|
||||
state.GetGlobal("__response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__response")
|
||||
state.GetGlobal("__response")
|
||||
}
|
||||
|
||||
state.PushNumber(status)
|
||||
state.SetField(-2, "status")
|
||||
|
||||
state.GetField(-1, "headers")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "headers")
|
||||
}
|
||||
|
||||
state.PushString(url)
|
||||
state.SetField(-2, "Location")
|
||||
state.Pop(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *Server) sessionGet(state *luajit.State) int {
|
||||
key, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, "data")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(2)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, key)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) sessionSet(state *luajit.State) int {
|
||||
key, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__session")
|
||||
state.GetGlobal("__session")
|
||||
}
|
||||
|
||||
state.GetField(-1, "data")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "data")
|
||||
}
|
||||
|
||||
value, err := state.ToValue(2)
|
||||
if err == nil {
|
||||
state.PushValue(value)
|
||||
state.SetField(-2, key)
|
||||
}
|
||||
state.Pop(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *Server) sessionFlash(state *luajit.State) int {
|
||||
key, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__session")
|
||||
state.GetGlobal("__session")
|
||||
}
|
||||
|
||||
state.GetField(-1, "flash")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "flash")
|
||||
}
|
||||
|
||||
value, err := state.ToValue(2)
|
||||
if err == nil {
|
||||
state.PushValue(value)
|
||||
state.SetField(-2, key)
|
||||
}
|
||||
state.Pop(2)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *Server) sessionGetFlash(state *luajit.State) int {
|
||||
key, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, "flash")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(2)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, key)
|
||||
return 1
|
||||
}
|
||||
|
||||
func cookieSet(state *luajit.State) int {
|
||||
name, _ := state.SafeToString(1)
|
||||
value, _ := state.SafeToString(2)
|
||||
|
||||
maxAge := 0
|
||||
path := "/"
|
||||
domain := ""
|
||||
secure := false
|
||||
httpOnly := true
|
||||
|
||||
if state.GetTop() >= 3 && state.IsTable(3) {
|
||||
maxAge = int(state.GetFieldNumber(3, "max_age", 0))
|
||||
path = state.GetFieldString(3, "path", "/")
|
||||
domain = state.GetFieldString(3, "domain", "")
|
||||
secure = state.GetFieldBool(3, "secure", false)
|
||||
httpOnly = state.GetFieldBool(3, "http_only", true)
|
||||
}
|
||||
|
||||
state.GetGlobal("__response")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__response")
|
||||
state.GetGlobal("__response")
|
||||
}
|
||||
|
||||
state.GetField(-1, "cookies")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "cookies")
|
||||
}
|
||||
|
||||
cookieData := map[string]any{
|
||||
"name": name,
|
||||
"value": value,
|
||||
"path": path,
|
||||
"secure": secure,
|
||||
"http_only": httpOnly,
|
||||
}
|
||||
if domain != "" {
|
||||
cookieData["domain"] = domain
|
||||
}
|
||||
if maxAge > 0 {
|
||||
cookieData["max_age"] = maxAge
|
||||
}
|
||||
|
||||
state.PushValue(cookieData)
|
||||
|
||||
length := globalServer.getTableLength(-2)
|
||||
state.PushNumber(float64(length + 1))
|
||||
state.PushCopy(-2)
|
||||
state.SetTable(-4)
|
||||
|
||||
state.Pop(3)
|
||||
return 0
|
||||
}
|
||||
|
||||
func cookieGet(state *luajit.State) int {
|
||||
name, _ := state.SafeToString(1)
|
||||
|
||||
state.GetGlobal("__request")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, "cookies")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(2)
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.GetField(-1, name)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) csrfGenerate(state *luajit.State) int {
|
||||
token := s.generateCSRFToken()
|
||||
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.SetGlobal("__session")
|
||||
state.GetGlobal("__session")
|
||||
}
|
||||
|
||||
state.GetField(-1, "data")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.NewTable()
|
||||
state.PushCopy(-1)
|
||||
state.SetField(-3, "data")
|
||||
}
|
||||
|
||||
state.PushString(token)
|
||||
state.SetField(-2, "_csrf_token")
|
||||
state.Pop(2)
|
||||
|
||||
state.PushString(token)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) csrfValidate(state *luajit.State) int {
|
||||
state.GetGlobal("__session")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
sessionToken := state.GetFieldString(-1, "data._csrf_token", "")
|
||||
state.Pop(1)
|
||||
|
||||
state.GetGlobal("__request")
|
||||
if state.IsNil(-1) {
|
||||
state.Pop(1)
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
requestToken := state.GetFieldString(-1, "form._csrf_token", "")
|
||||
state.Pop(1)
|
||||
|
||||
state.PushBoolean(sessionToken != "" && sessionToken == requestToken)
|
||||
return 1
|
||||
}
|
||||
|
||||
func (s *Server) getTableLength(index int) int {
|
||||
length := 0
|
||||
s.state.PushNil()
|
||||
for s.state.Next(index - 1) {
|
||||
length++
|
||||
s.state.Pop(1)
|
||||
}
|
||||
return length
|
||||
}
|
115
http/http.lua
Normal file
115
http/http.lua
Normal file
@ -0,0 +1,115 @@
|
||||
http = {}
|
||||
|
||||
function http.listen(port)
|
||||
return __http_listen(port)
|
||||
end
|
||||
|
||||
function http.route(method, path, handler)
|
||||
return __http_route(method, path, handler)
|
||||
end
|
||||
|
||||
function http.status(code)
|
||||
return __http_set_status(code)
|
||||
end
|
||||
|
||||
function http.header(name, value)
|
||||
return __http_set_header(name, value)
|
||||
end
|
||||
|
||||
function http.redirect(url, status)
|
||||
__http_redirect(url, status or 302)
|
||||
coroutine.yield() -- Exit handler
|
||||
end
|
||||
|
||||
function http.json(data)
|
||||
http.header("Content-Type", "application/json")
|
||||
return json.encode(data)
|
||||
end
|
||||
|
||||
function http.html(content)
|
||||
http.header("Content-Type", "text/html")
|
||||
return content
|
||||
end
|
||||
|
||||
function http.text(content)
|
||||
http.header("Content-Type", "text/plain")
|
||||
return content
|
||||
end
|
||||
|
||||
-- Session functions
|
||||
session = {}
|
||||
|
||||
function session.get(key)
|
||||
return __session_get(key)
|
||||
end
|
||||
|
||||
function session.set(key, value)
|
||||
return __session_set(key, value)
|
||||
end
|
||||
|
||||
function session.flash(key, value)
|
||||
return __session_flash(key, value)
|
||||
end
|
||||
|
||||
function session.get_flash(key)
|
||||
return __session_get_flash(key)
|
||||
end
|
||||
|
||||
-- Cookie functions
|
||||
cookie = {}
|
||||
|
||||
function cookie.set(name, value, options)
|
||||
return __cookie_set(name, value, options)
|
||||
end
|
||||
|
||||
function cookie.get(name)
|
||||
return __cookie_get(name)
|
||||
end
|
||||
|
||||
-- CSRF functions
|
||||
csrf = {}
|
||||
|
||||
function csrf.generate()
|
||||
return __csrf_generate()
|
||||
end
|
||||
|
||||
function csrf.validate()
|
||||
return __csrf_validate()
|
||||
end
|
||||
|
||||
function csrf.field()
|
||||
local token = csrf.generate()
|
||||
return string.format('<input type="hidden" name="_csrf_token" value="%s" />', token)
|
||||
end
|
||||
|
||||
-- Helper functions
|
||||
function redirect_with_flash(url, type, message)
|
||||
session.flash(type, message)
|
||||
http.redirect(url)
|
||||
end
|
||||
|
||||
-- JSON encoding/decoding placeholder
|
||||
json = {
|
||||
encode = function(data)
|
||||
-- Simplified JSON encoding
|
||||
if type(data) == "table" then
|
||||
local result = "{"
|
||||
local first = true
|
||||
for k, v in pairs(data) do
|
||||
if not first then result = result .. "," end
|
||||
result = result .. '"' .. tostring(k) .. '":' .. json.encode(v)
|
||||
first = false
|
||||
end
|
||||
return result .. "}"
|
||||
elseif type(data) == "string" then
|
||||
return '"' .. data .. '"'
|
||||
else
|
||||
return tostring(data)
|
||||
end
|
||||
end,
|
||||
|
||||
decode = function(str)
|
||||
-- Simplified JSON decoding - you'd want a proper implementation
|
||||
return {}
|
||||
end
|
||||
}
|
260
http/router/router.go
Normal file
260
http/router/router.go
Normal file
@ -0,0 +1,260 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// node represents a node in the radix trie
|
||||
type node struct {
|
||||
segment string
|
||||
handler int // Lua function reference
|
||||
children []*node
|
||||
isDynamic bool // :param
|
||||
isWildcard bool // *param
|
||||
paramName string
|
||||
}
|
||||
|
||||
// Router is a string-based HTTP router with efficient lookup
|
||||
type Router struct {
|
||||
get, post, put, patch, delete *node
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Params holds URL parameters
|
||||
type Params struct {
|
||||
Keys []string
|
||||
Values []string
|
||||
}
|
||||
|
||||
// Get returns a parameter value by name
|
||||
func (p *Params) Get(name string) string {
|
||||
for i, key := range p.Keys {
|
||||
if key == name && i < len(p.Values) {
|
||||
return p.Values[i]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// New creates a new Router instance
|
||||
func New() *Router {
|
||||
return &Router{
|
||||
get: &node{},
|
||||
post: &node{},
|
||||
put: &node{},
|
||||
patch: &node{},
|
||||
delete: &node{},
|
||||
}
|
||||
}
|
||||
|
||||
// methodNode returns the root node for a method
|
||||
func (r *Router) methodNode(method string) *node {
|
||||
switch method {
|
||||
case "GET":
|
||||
return r.get
|
||||
case "POST":
|
||||
return r.post
|
||||
case "PUT":
|
||||
return r.put
|
||||
case "PATCH":
|
||||
return r.patch
|
||||
case "DELETE":
|
||||
return r.delete
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddRoute adds a new route with handler reference
|
||||
func (r *Router) AddRoute(method, path string, handlerRef int) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
root := r.methodNode(method)
|
||||
if root == nil {
|
||||
return errors.New("unsupported HTTP method")
|
||||
}
|
||||
|
||||
if path == "/" {
|
||||
root.handler = handlerRef
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.addRoute(root, path, handlerRef)
|
||||
}
|
||||
|
||||
// addRoute adds a route to the trie
|
||||
func (r *Router) addRoute(root *node, path string, handlerRef int) error {
|
||||
segments := r.parseSegments(path)
|
||||
current := root
|
||||
|
||||
for _, seg := range segments {
|
||||
isDyn := strings.HasPrefix(seg, ":")
|
||||
isWC := strings.HasPrefix(seg, "*")
|
||||
|
||||
if isWC && seg != segments[len(segments)-1] {
|
||||
return errors.New("wildcard must be the last segment")
|
||||
}
|
||||
|
||||
paramName := ""
|
||||
if isDyn {
|
||||
paramName = seg[1:]
|
||||
seg = ":"
|
||||
} else if isWC {
|
||||
paramName = seg[1:]
|
||||
seg = "*"
|
||||
}
|
||||
|
||||
// Find or create child
|
||||
var child *node
|
||||
for _, c := range current.children {
|
||||
if c.segment == seg {
|
||||
child = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if child == nil {
|
||||
child = &node{
|
||||
segment: seg,
|
||||
isDynamic: isDyn,
|
||||
isWildcard: isWC,
|
||||
paramName: paramName,
|
||||
}
|
||||
current.children = append(current.children, child)
|
||||
}
|
||||
|
||||
current = child
|
||||
}
|
||||
|
||||
current.handler = handlerRef
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseSegments splits path into segments
|
||||
func (r *Router) parseSegments(path string) []string {
|
||||
segments := strings.Split(strings.Trim(path, "/"), "/")
|
||||
var result []string
|
||||
for _, seg := range segments {
|
||||
if seg != "" {
|
||||
result = append(result, seg)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Lookup finds handler and parameters for a method and path
|
||||
func (r *Router) Lookup(method, path string) (int, *Params, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
root := r.methodNode(method)
|
||||
if root == nil {
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
if path == "/" {
|
||||
if root.handler != 0 {
|
||||
return root.handler, &Params{}, true
|
||||
}
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
segments := r.parseSegments(path)
|
||||
handler, params := r.match(root, segments, 0)
|
||||
if handler == 0 {
|
||||
return 0, nil, false
|
||||
}
|
||||
|
||||
return handler, params, true
|
||||
}
|
||||
|
||||
// match traverses the trie to find handler
|
||||
func (r *Router) match(current *node, segments []string, index int) (int, *Params) {
|
||||
if index >= len(segments) {
|
||||
if current.handler != 0 {
|
||||
return current.handler, &Params{}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
segment := segments[index]
|
||||
|
||||
// Check exact match first
|
||||
for _, child := range current.children {
|
||||
if child.segment == segment {
|
||||
handler, params := r.match(child, segments, index+1)
|
||||
if handler != 0 {
|
||||
return handler, params
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check dynamic match second
|
||||
for _, child := range current.children {
|
||||
if child.isDynamic {
|
||||
handler, params := r.match(child, segments, index+1)
|
||||
if handler != 0 {
|
||||
// Prepend this parameter
|
||||
newParams := &Params{
|
||||
Keys: append([]string{child.paramName}, params.Keys...),
|
||||
Values: append([]string{segment}, params.Values...),
|
||||
}
|
||||
return handler, newParams
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check wildcard last (catches everything remaining)
|
||||
for _, child := range current.children {
|
||||
if child.isWildcard {
|
||||
remaining := strings.Join(segments[index:], "/")
|
||||
return child.handler, &Params{
|
||||
Keys: []string{child.paramName},
|
||||
Values: []string{remaining},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// RemoveRoute removes a route
|
||||
func (r *Router) RemoveRoute(method, path string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
root := r.methodNode(method)
|
||||
if root == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if path == "/" {
|
||||
root.handler = 0
|
||||
return
|
||||
}
|
||||
|
||||
segments := r.parseSegments(path)
|
||||
r.removeRoute(root, segments, 0)
|
||||
}
|
||||
|
||||
// removeRoute removes a route from the trie
|
||||
func (r *Router) removeRoute(current *node, segments []string, index int) {
|
||||
if index >= len(segments) {
|
||||
current.handler = 0
|
||||
return
|
||||
}
|
||||
|
||||
segment := segments[index]
|
||||
|
||||
for _, child := range current.children {
|
||||
if child.segment == segment ||
|
||||
(child.isDynamic && strings.HasPrefix(segment, ":")) ||
|
||||
(child.isWildcard && strings.HasPrefix(segment, "*")) {
|
||||
r.removeRoute(child, segments, index+1)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
167
http/server.go
167
http/server.go
@ -1,167 +0,0 @@
|
||||
// server.go - Simplified HTTP server
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"Moonshark/config"
|
||||
"Moonshark/metadata"
|
||||
"Moonshark/runner"
|
||||
"Moonshark/utils"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func NewHttpServer(cfg *config.Config, handler fasthttp.RequestHandler, dbg bool) *fasthttp.Server {
|
||||
return &fasthttp.Server{
|
||||
Handler: handler,
|
||||
Name: "Moonshark/" + metadata.Version,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
MaxRequestBodySize: 16 << 20,
|
||||
TCPKeepalive: true,
|
||||
ReduceMemoryUsage: true,
|
||||
StreamRequestBody: true,
|
||||
NoDefaultServerHeader: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NewPublicHandler(pubDir, prefix string) fasthttp.RequestHandler {
|
||||
if !strings.HasPrefix(prefix, "/") {
|
||||
prefix = "/" + prefix
|
||||
}
|
||||
if !strings.HasSuffix(prefix, "/") {
|
||||
prefix += "/"
|
||||
}
|
||||
|
||||
fs := &fasthttp.FS{
|
||||
Root: pubDir,
|
||||
IndexNames: []string{"index.html"},
|
||||
AcceptByteRange: true,
|
||||
Compress: true,
|
||||
CompressedFileSuffix: ".gz",
|
||||
CompressBrotli: true,
|
||||
PathRewrite: fasthttp.NewPathPrefixStripper(len(prefix) - 1),
|
||||
}
|
||||
return fs.NewRequestHandler()
|
||||
}
|
||||
|
||||
func Send404(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBody([]byte(utils.NotFoundPage(ctx.URI().String())))
|
||||
}
|
||||
|
||||
func Send500(ctx *fasthttp.RequestCtx, err error) {
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
|
||||
if err == nil {
|
||||
ctx.SetBody([]byte(utils.InternalErrorPage(string(ctx.Path()), "")))
|
||||
} else {
|
||||
ctx.SetBody([]byte(utils.InternalErrorPage(string(ctx.Path()), err.Error())))
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyResponse applies a Response to a fasthttp.RequestCtx
|
||||
func ApplyResponse(resp *runner.Response, ctx *fasthttp.RequestCtx) {
|
||||
// Set status code
|
||||
ctx.SetStatusCode(resp.Status)
|
||||
|
||||
// Set headers
|
||||
for name, value := range resp.Headers {
|
||||
ctx.Response.Header.Set(name, value)
|
||||
}
|
||||
|
||||
// Set cookies
|
||||
for _, cookie := range resp.Cookies {
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
}
|
||||
|
||||
// Process the body based on its type
|
||||
if resp.Body == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if Content-Type was manually set
|
||||
contentTypeSet := false
|
||||
for name := range resp.Headers {
|
||||
if strings.ToLower(name) == "content-type" {
|
||||
contentTypeSet = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Get a buffer from the pool
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
// Set body based on type
|
||||
switch body := resp.Body.(type) {
|
||||
case string:
|
||||
if !contentTypeSet {
|
||||
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
||||
}
|
||||
ctx.SetBodyString(body)
|
||||
case []byte:
|
||||
if !contentTypeSet {
|
||||
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
||||
}
|
||||
ctx.SetBody(body)
|
||||
case map[string]any, map[any]any, []any, []float64, []string, []int, []map[string]any:
|
||||
// Marshal JSON
|
||||
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
||||
if !contentTypeSet {
|
||||
ctx.Response.Header.SetContentType("application/json")
|
||||
}
|
||||
ctx.SetBody(buf.Bytes())
|
||||
} else {
|
||||
// Fallback to string representation
|
||||
if !contentTypeSet {
|
||||
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
||||
}
|
||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||
}
|
||||
default:
|
||||
// Check if it's any other map or slice type
|
||||
typeStr := fmt.Sprintf("%T", body)
|
||||
if typeStr[0] == '[' || (len(typeStr) > 3 && typeStr[:3] == "map") {
|
||||
if err := json.NewEncoder(buf).Encode(body); err == nil {
|
||||
if !contentTypeSet {
|
||||
ctx.Response.Header.SetContentType("application/json")
|
||||
}
|
||||
ctx.SetBody(buf.Bytes())
|
||||
} else {
|
||||
if !contentTypeSet {
|
||||
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
||||
}
|
||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||
}
|
||||
} else {
|
||||
// Default to string representation
|
||||
if !contentTypeSet {
|
||||
ctx.Response.Header.SetContentType("text/plain; charset=utf-8")
|
||||
}
|
||||
ctx.SetBodyString(fmt.Sprintf("%v", body))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func HandleDebugStats(ctx *fasthttp.RequestCtx) {
|
||||
stats := utils.CollectSystemStats(s.cfg)
|
||||
stats.Components = utils.ComponentStats{
|
||||
RouteCount: 0, // TODO: Get from router
|
||||
BytecodeBytes: 0, // TODO: Get from router
|
||||
SessionStats: s.sessionManager.GetCacheStats(),
|
||||
}
|
||||
ctx.SetContentType("text/html; charset=utf-8")
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetBody([]byte(utils.DebugStatsPage(stats)))
|
||||
}
|
||||
*/
|
208
logger/logger.go
208
logger/logger.go
@ -1,208 +0,0 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.sharkk.net/Go/Color"
|
||||
)
|
||||
|
||||
// Log levels
|
||||
const (
|
||||
LevelDebug = iota
|
||||
LevelInfo
|
||||
LevelWarn
|
||||
LevelError
|
||||
LevelFatal
|
||||
)
|
||||
|
||||
// Level config
|
||||
var levels = map[int]struct {
|
||||
tag string
|
||||
color func(string) string
|
||||
}{
|
||||
LevelDebug: {"D", color.Cyan},
|
||||
LevelInfo: {"I", color.Blue},
|
||||
LevelWarn: {"W", color.Yellow},
|
||||
LevelError: {"E", color.Red},
|
||||
LevelFatal: {"F", color.Purple},
|
||||
}
|
||||
|
||||
var (
|
||||
global *Logger
|
||||
globalOnce sync.Once
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
out io.Writer
|
||||
enabled atomic.Bool
|
||||
timestamp atomic.Bool
|
||||
debug atomic.Bool
|
||||
http atomic.Bool
|
||||
colors atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func init() {
|
||||
globalOnce.Do(func() {
|
||||
global = &Logger{out: os.Stdout}
|
||||
global.enabled.Store(true)
|
||||
global.timestamp.Store(true)
|
||||
global.colors.Store(true)
|
||||
})
|
||||
}
|
||||
|
||||
func applyColor(text string, colorFunc func(string) string) string {
|
||||
if global.colors.Load() {
|
||||
return colorFunc(text)
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func write(level int, msg string) {
|
||||
if !global.enabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
if level == LevelDebug && !global.debug.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
var parts []string
|
||||
|
||||
if global.timestamp.Load() {
|
||||
ts := applyColor(time.Now().Format("3:04PM"), color.Gray)
|
||||
parts = append(parts, ts)
|
||||
}
|
||||
|
||||
if cfg, ok := levels[level]; ok {
|
||||
tag := applyColor("["+cfg.tag+"]", cfg.color)
|
||||
parts = append(parts, tag)
|
||||
}
|
||||
|
||||
parts = append(parts, msg)
|
||||
line := strings.Join(parts, " ") + "\n"
|
||||
|
||||
global.mu.Lock()
|
||||
fmt.Fprint(global.out, line)
|
||||
if level == LevelFatal {
|
||||
if f, ok := global.out.(*os.File); ok {
|
||||
f.Sync()
|
||||
}
|
||||
}
|
||||
global.mu.Unlock()
|
||||
|
||||
if level == LevelFatal {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func log(level int, format string, args ...any) {
|
||||
var msg string
|
||||
if len(args) > 0 {
|
||||
msg = fmt.Sprintf(format, args...)
|
||||
} else {
|
||||
msg = format
|
||||
}
|
||||
write(level, msg)
|
||||
}
|
||||
|
||||
func Debugf(format string, args ...any) { log(LevelDebug, format, args...) }
|
||||
func Infof(format string, args ...any) { log(LevelInfo, format, args...) }
|
||||
func Warnf(format string, args ...any) { log(LevelWarn, format, args...) }
|
||||
func Errorf(format string, args ...any) { log(LevelError, format, args...) }
|
||||
func Fatalf(format string, args ...any) { log(LevelFatal, format, args...) }
|
||||
|
||||
func Raw(format string, args ...any) {
|
||||
if !global.enabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
var msg string
|
||||
if len(args) > 0 {
|
||||
msg = fmt.Sprintf(format, args...)
|
||||
} else {
|
||||
msg = format
|
||||
}
|
||||
|
||||
global.mu.Lock()
|
||||
fmt.Fprint(global.out, msg+"\n")
|
||||
global.mu.Unlock()
|
||||
}
|
||||
|
||||
// Attempts to write the HTTP request result to the log. Will not print if
|
||||
// http is disabled on the logger. Method and path take byte slices for convenience.
|
||||
func Request(status int, method, path []byte, duration time.Duration) {
|
||||
if !global.enabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
if !global.http.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
var statusColor func(string) string
|
||||
switch {
|
||||
case status < 300:
|
||||
statusColor = color.Green
|
||||
case status < 400:
|
||||
statusColor = color.Cyan
|
||||
case status < 500:
|
||||
statusColor = color.Yellow
|
||||
default:
|
||||
statusColor = color.Red
|
||||
}
|
||||
|
||||
var dur string
|
||||
us := duration.Microseconds()
|
||||
switch {
|
||||
case us < 1000:
|
||||
dur = fmt.Sprintf("%.0fµs", float64(us))
|
||||
case us < 1000000:
|
||||
dur = fmt.Sprintf("%.1fms", float64(us)/1000)
|
||||
default:
|
||||
dur = fmt.Sprintf("%.2fs", duration.Seconds())
|
||||
}
|
||||
|
||||
var parts []string
|
||||
|
||||
if global.timestamp.Load() {
|
||||
ts := applyColor(time.Now().Format("3:04PM"), color.Gray)
|
||||
parts = append(parts, ts)
|
||||
}
|
||||
|
||||
parts = append(parts,
|
||||
applyColor("["+string(method)+"]", color.Gray),
|
||||
applyColor(fmt.Sprintf("%d", status), statusColor),
|
||||
applyColor(string(path), color.Gray),
|
||||
applyColor(dur, color.Gray),
|
||||
)
|
||||
|
||||
msg := strings.Join(parts, " ")
|
||||
|
||||
global.mu.Lock()
|
||||
fmt.Fprint(global.out, msg+"\n")
|
||||
global.mu.Unlock()
|
||||
}
|
||||
|
||||
func SetOutput(w io.Writer) {
|
||||
global.mu.Lock()
|
||||
global.out = w
|
||||
global.mu.Unlock()
|
||||
}
|
||||
|
||||
func Enable() { global.enabled.Store(true) }
|
||||
func Disable() { global.enabled.Store(false) }
|
||||
func IsEnabled() bool { return global.enabled.Load() }
|
||||
func EnableColors() { global.colors.Store(true) }
|
||||
func DisableColors() { global.colors.Store(false) }
|
||||
func ColorsEnabled() bool { return global.colors.Load() }
|
||||
func Timestamp(enabled bool) { global.timestamp.Store(enabled) }
|
||||
func Debug(enabled bool) { global.debug.Store(enabled) }
|
||||
func IsDebug() bool { return global.debug.Load() }
|
||||
func Http(enabled bool) { global.debug.Store(enabled) }
|
374
moonshark.go
374
moonshark.go
@ -1,362 +1,68 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"Moonshark/config"
|
||||
"Moonshark/http"
|
||||
"Moonshark/logger"
|
||||
"Moonshark/metadata"
|
||||
"Moonshark/router"
|
||||
"Moonshark/runner"
|
||||
"Moonshark/sessions"
|
||||
"Moonshark/utils"
|
||||
"Moonshark/watchers"
|
||||
"bytes"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
color "git.sharkk.net/Go/Color"
|
||||
"Moonshark/http"
|
||||
|
||||
fin "git.sharkk.net/Sharkk/Fin"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var (
|
||||
cfg *config.Config // Server config from Fin file
|
||||
rtr *router.Router // Lua file router
|
||||
rnr *runner.Runner // Lua runner
|
||||
svr *fasthttp.Server // FastHTTP server
|
||||
pub fasthttp.RequestHandler // Public asset handler
|
||||
snm *sessions.SessionManager // Session data manager
|
||||
wmg *watchers.WatcherManager // Watcher manager
|
||||
dbg bool // Debug mode flag
|
||||
pubPfx []byte // Cached public asset prefix
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfgPath := flag.String("config", "config", "Path to Fin config file")
|
||||
dbgFlag := flag.Bool("debug", false, "Force debug mode")
|
||||
sptPath := flag.String("script", "", "Path to Lua script to execute once")
|
||||
flag.Parse()
|
||||
|
||||
// Init sequence
|
||||
sptMode := *sptPath != ""
|
||||
color.SetColors(color.DetectShellColors())
|
||||
banner(sptMode)
|
||||
|
||||
// Load Fin-based config
|
||||
cfg = config.New(fin.LoadFromFile(*cfgPath))
|
||||
|
||||
// Setup debug mode
|
||||
dbg = *dbgFlag || cfg.Server.Debug
|
||||
logger.Debug(dbg)
|
||||
logger.Debugf("Debug logging enabled") // Only prints if dbg is true
|
||||
utils.Debug(dbg) // @TODO find a better way to do this
|
||||
|
||||
// Determine Lua runner pool size
|
||||
poolSize := cfg.Runner.PoolSize
|
||||
if sptMode {
|
||||
poolSize = 1
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s <lua_file>\n", os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set up the Lua runner
|
||||
if err := initRunner(poolSize); err != nil {
|
||||
logger.Fatalf("Runner failed to init: %v", err)
|
||||
luaFile := os.Args[1]
|
||||
|
||||
if _, err := os.Stat(luaFile); os.IsNotExist(err) {
|
||||
fmt.Fprintf(os.Stderr, "Error: File '%s' not found\n", luaFile)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// If in script mode, attempt to run the Lua script at the given path
|
||||
if sptMode {
|
||||
if err := handleScriptMode(*sptPath); err != nil {
|
||||
logger.Fatalf("Script execution failed: %v", err)
|
||||
}
|
||||
|
||||
shutdown()
|
||||
return
|
||||
}
|
||||
|
||||
// Set up the Lua router
|
||||
if err := initRouter(); err != nil {
|
||||
logger.Fatalf("Router failed to init: %s", color.Red(err.Error()))
|
||||
}
|
||||
|
||||
// Set up the file watcher manager
|
||||
if err := setupWatchers(); err != nil {
|
||||
logger.Fatalf("Watcher manager failed to init: %s", color.Red(err.Error()))
|
||||
}
|
||||
|
||||
// Set up the HTTP portion of the server
|
||||
logger.Http(cfg.Server.HTTPLogging) // Whether we'll log HTTP request results
|
||||
svr = http.NewHttpServer(cfg, requestMux, dbg)
|
||||
pub = http.NewPublicHandler(cfg.Dirs.Public, cfg.Server.PublicPrefix)
|
||||
pubPfx = []byte(cfg.Server.PublicPrefix) // Avoids casting to []byte when check prefixes
|
||||
snm = sessions.NewSessionManager(sessions.DefaultMaxSessions)
|
||||
|
||||
// Start the HTTP server
|
||||
logger.Infof("Surf's up on port %s!", color.Cyan(strconv.Itoa(cfg.Server.Port)))
|
||||
go func() {
|
||||
if err := svr.ListenAndServe(":" + strconv.Itoa(cfg.Server.Port)); err != nil {
|
||||
if err.Error() != "http: Server closed" {
|
||||
logger.Errorf("Server error: %v", err)
|
||||
}
|
||||
// Create long-lived LuaJIT state
|
||||
L := luajit.New(true)
|
||||
if L == nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: Failed to create Lua state\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
}()
|
||||
|
||||
// Handle a shutdown signal
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||
<-stop
|
||||
|
||||
fmt.Print("\n")
|
||||
logger.Infof("Shutdown signal received")
|
||||
shutdown()
|
||||
// Register HTTP functions
|
||||
if err := http.RegisterHTTPFunctions(L); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error registering HTTP functions: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// This is the primary request handler mux - determines whether we need to handle a Lua
|
||||
// route or if we're serving a static file.
|
||||
func requestMux(ctx *fasthttp.RequestCtx) {
|
||||
start := time.Now()
|
||||
method := ctx.Method()
|
||||
path := ctx.Path()
|
||||
|
||||
// Handle static file request
|
||||
if bytes.HasPrefix(path, pubPfx) {
|
||||
pub(ctx)
|
||||
logRequest(ctx, method, path, start)
|
||||
return
|
||||
// Execute the Lua file
|
||||
if err := L.DoFile(luaFile); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// See if the requested route even exists
|
||||
bytecode, params, found := rtr.Lookup(string(method), string(path))
|
||||
if !found {
|
||||
http.Send404(ctx)
|
||||
logRequest(ctx, method, path, start)
|
||||
return
|
||||
// Handle return value for immediate exit
|
||||
if L.GetTop() > 0 {
|
||||
if L.IsNumber(1) {
|
||||
exitCode := int(L.ToNumber(1))
|
||||
if exitCode != 0 {
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
// If there's no bytecode then it's an internal server error
|
||||
if len(bytecode) == 0 {
|
||||
http.Send500(ctx, nil)
|
||||
logRequest(ctx, method, path, start)
|
||||
}
|
||||
|
||||
// We've made it this far so the endpoint will likely load. Let's get any session data
|
||||
// for this request
|
||||
session := snm.GetSessionFromRequest(ctx)
|
||||
|
||||
// Let's build an HTTP context for the Lua runner to consume
|
||||
luaCtx := runner.NewHTTPContext(ctx, params, session)
|
||||
defer luaCtx.Release()
|
||||
|
||||
// Ask the runner to execute our endpoint with our context
|
||||
res, err := rnr.Execute(bytecode, luaCtx)
|
||||
if err != nil {
|
||||
logger.Errorf("Lua execution error: %v", err)
|
||||
http.Send500(ctx, err)
|
||||
logRequest(ctx, method, path, start)
|
||||
return
|
||||
}
|
||||
|
||||
// Sweet, our execution went through! Let's now use the Response we got and build the HTTP response, then return
|
||||
// the response object to be cleaned. After, we'll log our request cus we are *done*
|
||||
applyResponse(ctx, res, session)
|
||||
runner.ReleaseResponse(res)
|
||||
logRequest(ctx, method, path, start)
|
||||
}
|
||||
|
||||
func applyResponse(ctx *fasthttp.RequestCtx, resp *runner.Response, session *sessions.Session) {
|
||||
// Handle session updates
|
||||
if len(resp.SessionData) > 0 {
|
||||
if _, clearAll := resp.SessionData["__clear_all"]; clearAll {
|
||||
session.Clear()
|
||||
session.ClearFlash()
|
||||
delete(resp.SessionData, "__clear_all")
|
||||
}
|
||||
|
||||
for k, v := range resp.SessionData {
|
||||
if v == "__DELETE__" {
|
||||
session.Delete(k)
|
||||
} else {
|
||||
session.Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle flash data
|
||||
if flashData, ok := resp.Metadata["flash"].(map[string]any); ok {
|
||||
for k, v := range flashData {
|
||||
if err := session.FlashSafe(k, v); err != nil && dbg {
|
||||
logger.Warnf("Error setting flash data %s: %v", k, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply session cookie
|
||||
snm.ApplySessionCookie(ctx, session)
|
||||
|
||||
// Apply HTTP response
|
||||
http.ApplyResponse(resp, ctx)
|
||||
}
|
||||
|
||||
// Attempts to start the Lua runner. poolSize allows overriding the config, like for script mode. A poolSize of
|
||||
// 0 will default to the config, and if the config is 0 then it will default to GOMAXPROCS.
|
||||
func initRunner(poolSize int) error {
|
||||
for _, dir := range cfg.Dirs.Libs {
|
||||
if !dirExists(dir) {
|
||||
logger.Warnf("Lib directory not found... %s", color.Yellow(dir))
|
||||
} else if L.IsBoolean(1) && !L.ToBoolean(1) {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
runner, err := runner.NewRunner(cfg, poolSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lua runner init failed: %v", err)
|
||||
}
|
||||
rnr = runner
|
||||
|
||||
logger.Infof("LuaRunner is g2g with %s states!", color.Yellow(strconv.Itoa(poolSize)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Attempt to spin up the Lua router. Attempts to create the routes directory if it doesn't exist,
|
||||
// since it's required for Moonshark to work.
|
||||
func initRouter() error {
|
||||
if err := os.MkdirAll(cfg.Dirs.Routes, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create routes directory: %w", err)
|
||||
}
|
||||
|
||||
router, err := router.New(cfg.Dirs.Routes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lua router init failed: %v", err)
|
||||
}
|
||||
rtr = router
|
||||
|
||||
logger.Infof("LuaRouter is g2g! %s", color.Yellow(cfg.Dirs.Routes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set up the file watchers.
|
||||
func setupWatchers() error {
|
||||
wmg = watchers.NewWatcherManager()
|
||||
|
||||
// Router watcher
|
||||
err := wmg.WatchDirectory(watchers.WatcherConfig{
|
||||
Dir: cfg.Dirs.Routes,
|
||||
Callback: rtr.Refresh,
|
||||
Recursive: true,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to watch routes directory: %v", err)
|
||||
}
|
||||
|
||||
logger.Infof("Started watching Lua routes! %s", color.Yellow(cfg.Dirs.Routes))
|
||||
|
||||
// Libs watchers
|
||||
for _, dir := range cfg.Dirs.Libs {
|
||||
err := wmg.WatchDirectory(watchers.WatcherConfig{
|
||||
Dir: dir,
|
||||
Callback: func(changes []watchers.FileChange) error {
|
||||
for _, change := range changes {
|
||||
if !change.IsDeleted && strings.HasSuffix(change.Path, ".lua") {
|
||||
rnr.NotifyFileChanged(change.Path)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Recursive: true,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to watch modules directory: %v", err)
|
||||
}
|
||||
|
||||
logger.Infof("Started watching Lua modules! %s", color.Yellow(dir))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Attempts to execute the Lua script at the given path inside a fully initialized sandbox environment. Handy
|
||||
// for pre-launch tasks and the like.
|
||||
func handleScriptMode(path string) error {
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve script path: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return fmt.Errorf("script file not found: %s", path)
|
||||
}
|
||||
|
||||
logger.Infof("Executing: %s", path)
|
||||
|
||||
resp, err := rnr.RunScriptFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("execution failed: %v", err)
|
||||
}
|
||||
|
||||
if resp != nil && resp.Body != nil {
|
||||
logger.Infof("Script result: %v", resp.Body)
|
||||
} else {
|
||||
logger.Infof("Script executed successfully (no return value)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func shutdown() {
|
||||
logger.Infof("Shutting down...")
|
||||
|
||||
// Close down the HTTP server
|
||||
if svr != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := svr.ShutdownWithContext(ctx); err != nil {
|
||||
logger.Errorf("HTTP server shutdown error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close down the Lua runner if it exists
|
||||
if rnr != nil {
|
||||
rnr.Close()
|
||||
}
|
||||
|
||||
// Close down the watcher manager if it exists
|
||||
if wmg != nil {
|
||||
wmg.Close()
|
||||
}
|
||||
|
||||
logger.Infof("Shutdown complete")
|
||||
}
|
||||
|
||||
// Print our super-awesome banner with the current version!
|
||||
func banner(scriptMode bool) {
|
||||
if scriptMode {
|
||||
fmt.Println(color.Blue(fmt.Sprintf("Moonshark %s << Script Mode >>", metadata.Version)))
|
||||
return
|
||||
}
|
||||
|
||||
banner := `
|
||||
_____ _________.__ __
|
||||
/ \ ____ ____ ____ / _____/| |__ _____ _______| | __
|
||||
/ \ / \ / _ \ / _ \ / \ \_____ \ | | \\__ \\_ __ \ |/ /
|
||||
/ Y ( <_> | <_> ) | \/ \| Y \/ __ \| | \/ <
|
||||
\____|__ /\____/ \____/|___| /_______ /|___| (____ /__| |__|_ \ %s
|
||||
\/ \/ \/ \/ \/ \/
|
||||
`
|
||||
fmt.Println(color.Blue(fmt.Sprintf(banner, metadata.Version)))
|
||||
}
|
||||
|
||||
func dirExists(path string) bool {
|
||||
info, err := os.Stat(path)
|
||||
return err == nil && info.IsDir()
|
||||
}
|
||||
|
||||
func logRequest(ctx *fasthttp.RequestCtx, method, path []byte, start time.Time) {
|
||||
logger.Request(ctx.Response.StatusCode(), method, path, time.Since(start))
|
||||
// Keep running for HTTP server
|
||||
fmt.Println("Script executed. Press Ctrl+C to exit.")
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
fmt.Println("\nShutting down...")
|
||||
}
|
||||
|
502
router/router.go
502
router/router.go
@ -1,502 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"Moonshark/watchers"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/VictoriaMetrics/fastcache"
|
||||
)
|
||||
|
||||
// node represents a node in the radix trie
|
||||
type node struct {
|
||||
segment string
|
||||
bytecode []byte
|
||||
scriptPath string
|
||||
children []*node
|
||||
isDynamic bool
|
||||
isWildcard bool
|
||||
maxParams uint8
|
||||
}
|
||||
|
||||
// Router is a filesystem-based HTTP router for Lua files with bytecode caching
|
||||
type Router struct {
|
||||
routesDir string
|
||||
get, post, put, patch, delete *node
|
||||
bytecodeCache *fastcache.Cache
|
||||
compileState *luajit.State
|
||||
compileMu sync.Mutex
|
||||
paramsBuffer []string
|
||||
middlewareFiles map[string][]string // filesystem path -> middleware file paths
|
||||
}
|
||||
|
||||
// Params holds URL parameters
|
||||
type Params struct {
|
||||
Keys []string
|
||||
Values []string
|
||||
}
|
||||
|
||||
// Get returns a parameter value by name
|
||||
func (p *Params) Get(name string) string {
|
||||
for i, key := range p.Keys {
|
||||
if key == name && i < len(p.Values) {
|
||||
return p.Values[i]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// New creates a new Router instance
|
||||
func New(routesDir string) (*Router, error) {
|
||||
info, err := os.Stat(routesDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, errors.New("routes path is not a directory")
|
||||
}
|
||||
|
||||
compileState := luajit.New()
|
||||
if compileState == nil {
|
||||
return nil, errors.New("failed to create Lua compile state")
|
||||
}
|
||||
|
||||
r := &Router{
|
||||
routesDir: routesDir,
|
||||
get: &node{},
|
||||
post: &node{},
|
||||
put: &node{},
|
||||
patch: &node{},
|
||||
delete: &node{},
|
||||
bytecodeCache: fastcache.New(32 * 1024 * 1024), // 32MB
|
||||
compileState: compileState,
|
||||
paramsBuffer: make([]string, 64),
|
||||
middlewareFiles: make(map[string][]string),
|
||||
}
|
||||
|
||||
return r, r.buildRoutes()
|
||||
}
|
||||
|
||||
// methodNode returns the root node for a method
|
||||
func (r *Router) methodNode(method string) *node {
|
||||
switch method {
|
||||
case "GET":
|
||||
return r.get
|
||||
case "POST":
|
||||
return r.post
|
||||
case "PUT":
|
||||
return r.put
|
||||
case "PATCH":
|
||||
return r.patch
|
||||
case "DELETE":
|
||||
return r.delete
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// buildRoutes scans the routes directory and builds the routing tree
|
||||
func (r *Router) buildRoutes() error {
|
||||
r.middlewareFiles = make(map[string][]string)
|
||||
|
||||
// First pass: collect all middleware files
|
||||
err := filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.TrimSuffix(info.Name(), ".lua") == "middleware" {
|
||||
relDir, err := filepath.Rel(r.routesDir, filepath.Dir(path))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fsPath := "/"
|
||||
if relDir != "." {
|
||||
fsPath = "/" + strings.ReplaceAll(relDir, "\\", "/")
|
||||
}
|
||||
|
||||
r.middlewareFiles[fsPath] = append(r.middlewareFiles[fsPath], path)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Second pass: build routes
|
||||
return filepath.Walk(r.routesDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".lua") {
|
||||
return err
|
||||
}
|
||||
|
||||
fileName := strings.TrimSuffix(info.Name(), ".lua")
|
||||
|
||||
// Skip middleware files
|
||||
if fileName == "middleware" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get relative path from routes directory
|
||||
relPath, err := filepath.Rel(r.routesDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get filesystem path (includes groups)
|
||||
fsPath := "/" + strings.ReplaceAll(filepath.Dir(relPath), "\\", "/")
|
||||
if fsPath == "/." {
|
||||
fsPath = "/"
|
||||
}
|
||||
|
||||
// Get URL path (excludes groups)
|
||||
urlPath := r.parseURLPath(fsPath)
|
||||
|
||||
// Handle method files (get.lua, post.lua, etc.)
|
||||
method := strings.ToUpper(fileName)
|
||||
root := r.methodNode(method)
|
||||
if root != nil {
|
||||
return r.addRoute(root, urlPath, fsPath, path)
|
||||
}
|
||||
|
||||
// Handle index files - register for all methods
|
||||
if fileName == "index" {
|
||||
for _, method := range []string{"GET", "POST", "PUT", "PATCH", "DELETE"} {
|
||||
if root := r.methodNode(method); root != nil {
|
||||
if err := r.addRoute(root, urlPath, fsPath, path); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle named route files - register as GET by default
|
||||
namedPath := urlPath
|
||||
if urlPath == "/" {
|
||||
namedPath = "/" + fileName
|
||||
} else {
|
||||
namedPath = urlPath + "/" + fileName
|
||||
}
|
||||
return r.addRoute(r.get, namedPath, fsPath, path)
|
||||
})
|
||||
}
|
||||
|
||||
// parseURLPath strips group segments from filesystem path
|
||||
func (r *Router) parseURLPath(fsPath string) string {
|
||||
segments := strings.Split(strings.Trim(fsPath, "/"), "/")
|
||||
var urlSegments []string
|
||||
|
||||
for _, segment := range segments {
|
||||
if segment == "" {
|
||||
continue
|
||||
}
|
||||
// Skip group segments (enclosed in parentheses)
|
||||
if strings.HasPrefix(segment, "(") && strings.HasSuffix(segment, ")") {
|
||||
continue
|
||||
}
|
||||
urlSegments = append(urlSegments, segment)
|
||||
}
|
||||
|
||||
if len(urlSegments) == 0 {
|
||||
return "/"
|
||||
}
|
||||
return "/" + strings.Join(urlSegments, "/")
|
||||
}
|
||||
|
||||
// getMiddlewareChain returns middleware files that apply to the given filesystem path
|
||||
func (r *Router) getMiddlewareChain(fsPath string) []string {
|
||||
var chain []string
|
||||
|
||||
pathParts := strings.Split(strings.Trim(fsPath, "/"), "/")
|
||||
if pathParts[0] == "" {
|
||||
pathParts = []string{}
|
||||
}
|
||||
|
||||
// Add root middleware
|
||||
if mw, exists := r.middlewareFiles["/"]; exists {
|
||||
chain = append(chain, mw...)
|
||||
}
|
||||
|
||||
// Add middleware from each path level (including groups)
|
||||
currentPath := ""
|
||||
for _, part := range pathParts {
|
||||
currentPath += "/" + part
|
||||
if mw, exists := r.middlewareFiles[currentPath]; exists {
|
||||
chain = append(chain, mw...)
|
||||
}
|
||||
}
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
// buildCombinedSource combines middleware and handler source
|
||||
func (r *Router) buildCombinedSource(fsPath, scriptPath string) (string, error) {
|
||||
var combined strings.Builder
|
||||
|
||||
// Add middleware in order
|
||||
middlewareChain := r.getMiddlewareChain(fsPath)
|
||||
for _, mwPath := range middlewareChain {
|
||||
content, err := os.ReadFile(mwPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
combined.WriteString("-- Middleware: ")
|
||||
combined.WriteString(mwPath)
|
||||
combined.WriteString("\n")
|
||||
combined.Write(content)
|
||||
combined.WriteString("\n")
|
||||
}
|
||||
|
||||
// Add main handler
|
||||
content, err := os.ReadFile(scriptPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
combined.WriteString("-- Handler: ")
|
||||
combined.WriteString(scriptPath)
|
||||
combined.WriteString("\n")
|
||||
combined.Write(content)
|
||||
|
||||
return combined.String(), nil
|
||||
}
|
||||
|
||||
// addRoute adds a new route to the trie with bytecode compilation
|
||||
func (r *Router) addRoute(root *node, urlPath, fsPath, scriptPath string) error {
|
||||
// Build combined source with middleware
|
||||
combinedSource, err := r.buildCombinedSource(fsPath, scriptPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compile bytecode
|
||||
r.compileMu.Lock()
|
||||
bytecode, err := r.compileState.CompileBytecode(combinedSource, scriptPath)
|
||||
r.compileMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache bytecode
|
||||
cacheKey := hashString(scriptPath)
|
||||
r.bytecodeCache.Set(uint64ToBytes(cacheKey), bytecode)
|
||||
|
||||
if urlPath == "/" {
|
||||
root.bytecode = bytecode
|
||||
root.scriptPath = scriptPath
|
||||
return nil
|
||||
}
|
||||
|
||||
current := root
|
||||
pos := 0
|
||||
paramCount := uint8(0)
|
||||
|
||||
for {
|
||||
seg, newPos, more := readSegment(urlPath, pos)
|
||||
if seg == "" {
|
||||
break
|
||||
}
|
||||
|
||||
isDyn := len(seg) > 2 && seg[0] == '[' && seg[len(seg)-1] == ']'
|
||||
isWC := len(seg) > 0 && seg[0] == '*'
|
||||
|
||||
if isWC && more {
|
||||
return errors.New("wildcard must be the last segment")
|
||||
}
|
||||
|
||||
if isDyn || isWC {
|
||||
paramCount++
|
||||
}
|
||||
|
||||
// Find or create child
|
||||
var child *node
|
||||
for _, c := range current.children {
|
||||
if c.segment == seg {
|
||||
child = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if child == nil {
|
||||
child = &node{
|
||||
segment: seg,
|
||||
isDynamic: isDyn,
|
||||
isWildcard: isWC,
|
||||
}
|
||||
current.children = append(current.children, child)
|
||||
}
|
||||
|
||||
if child.maxParams < paramCount {
|
||||
child.maxParams = paramCount
|
||||
}
|
||||
|
||||
current = child
|
||||
pos = newPos
|
||||
}
|
||||
|
||||
current.bytecode = bytecode
|
||||
current.scriptPath = scriptPath
|
||||
return nil
|
||||
}
|
||||
|
||||
// readSegment extracts the next path segment
|
||||
func readSegment(path string, start int) (segment string, end int, hasMore bool) {
|
||||
if start >= len(path) {
|
||||
return "", start, false
|
||||
}
|
||||
if path[start] == '/' {
|
||||
start++
|
||||
}
|
||||
if start >= len(path) {
|
||||
return "", start, false
|
||||
}
|
||||
end = start
|
||||
for end < len(path) && path[end] != '/' {
|
||||
end++
|
||||
}
|
||||
return path[start:end], end, end < len(path)
|
||||
}
|
||||
|
||||
// Lookup finds bytecode and parameters for a method and path
|
||||
func (r *Router) Lookup(method, path string) ([]byte, *Params, bool) {
|
||||
root := r.methodNode(method)
|
||||
if root == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
if path == "/" {
|
||||
if root.bytecode != nil {
|
||||
return root.bytecode, &Params{}, true
|
||||
}
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
// Prepare params buffer
|
||||
buffer := r.paramsBuffer
|
||||
if cap(buffer) < int(root.maxParams) {
|
||||
buffer = make([]string, root.maxParams)
|
||||
r.paramsBuffer = buffer
|
||||
}
|
||||
buffer = buffer[:0]
|
||||
|
||||
var keys []string
|
||||
bytecode, paramCount, found := r.match(root, path, 0, &buffer, &keys)
|
||||
if !found {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
params := &Params{
|
||||
Keys: keys[:paramCount],
|
||||
Values: buffer[:paramCount],
|
||||
}
|
||||
|
||||
return bytecode, params, true
|
||||
}
|
||||
|
||||
// match traverses the trie to find bytecode
|
||||
func (r *Router) match(current *node, path string, start int, params *[]string, keys *[]string) ([]byte, int, bool) {
|
||||
paramCount := 0
|
||||
|
||||
// Check wildcard first
|
||||
for _, c := range current.children {
|
||||
if c.isWildcard {
|
||||
rem := path[start:]
|
||||
if len(rem) > 0 && rem[0] == '/' {
|
||||
rem = rem[1:]
|
||||
}
|
||||
*params = append(*params, rem)
|
||||
*keys = append(*keys, strings.TrimPrefix(c.segment, "*"))
|
||||
return c.bytecode, 1, c.bytecode != nil
|
||||
}
|
||||
}
|
||||
|
||||
seg, pos, more := readSegment(path, start)
|
||||
if seg == "" {
|
||||
return current.bytecode, 0, current.bytecode != nil
|
||||
}
|
||||
|
||||
for _, c := range current.children {
|
||||
if c.segment == seg || c.isDynamic {
|
||||
if c.isDynamic {
|
||||
*params = append(*params, seg)
|
||||
paramName := c.segment[1 : len(c.segment)-1] // Remove [ ]
|
||||
*keys = append(*keys, paramName)
|
||||
paramCount++
|
||||
}
|
||||
|
||||
if !more {
|
||||
return c.bytecode, paramCount, c.bytecode != nil
|
||||
}
|
||||
|
||||
bytecode, nestedCount, ok := r.match(c, path, pos, params, keys)
|
||||
if ok {
|
||||
return bytecode, paramCount + nestedCount, true
|
||||
}
|
||||
|
||||
// Backtrack on failure
|
||||
if c.isDynamic {
|
||||
*params = (*params)[:len(*params)-1]
|
||||
*keys = (*keys)[:len(*keys)-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
// GetBytecode gets cached bytecode by script path
|
||||
func (r *Router) GetBytecode(scriptPath string) []byte {
|
||||
cacheKey := hashString(scriptPath)
|
||||
return r.bytecodeCache.Get(nil, uint64ToBytes(cacheKey))
|
||||
}
|
||||
|
||||
// Refresh rebuilds the router
|
||||
func (r *Router) Refresh(changes []watchers.FileChange) error {
|
||||
r.get = &node{}
|
||||
r.post = &node{}
|
||||
r.put = &node{}
|
||||
r.patch = &node{}
|
||||
r.delete = &node{}
|
||||
r.middlewareFiles = make(map[string][]string)
|
||||
r.bytecodeCache.Reset()
|
||||
return r.buildRoutes()
|
||||
}
|
||||
|
||||
// Close cleans up resources
|
||||
func (r *Router) Close() {
|
||||
r.compileMu.Lock()
|
||||
if r.compileState != nil {
|
||||
r.compileState.Close()
|
||||
r.compileState = nil
|
||||
}
|
||||
r.compileMu.Unlock()
|
||||
}
|
||||
|
||||
// Helper functions from cache.go
|
||||
func hashString(s string) uint64 {
|
||||
h := uint64(5381)
|
||||
for i := 0; i < len(s); i++ {
|
||||
h = ((h << 5) + h) + uint64(s[i])
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func uint64ToBytes(n uint64) []byte {
|
||||
b := make([]byte, 8)
|
||||
b[0] = byte(n)
|
||||
b[1] = byte(n >> 8)
|
||||
b[2] = byte(n >> 16)
|
||||
b[3] = byte(n >> 24)
|
||||
b[4] = byte(n >> 32)
|
||||
b[5] = byte(n >> 40)
|
||||
b[6] = byte(n >> 48)
|
||||
b[7] = byte(n >> 56)
|
||||
return b
|
||||
}
|
@ -1,318 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setupTestRoutes(t testing.TB) string {
|
||||
t.Helper()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create test route files
|
||||
routes := map[string]string{
|
||||
"index.lua": `return "home"`,
|
||||
"about.lua": `return "about"`,
|
||||
"api/users.lua": `return "users"`,
|
||||
"api/users/get.lua": `return "get_users"`,
|
||||
"api/users/post.lua": `return "create_user"`,
|
||||
"api/users/[id].lua": `return "user_" .. id`,
|
||||
"api/posts/[slug]/comments.lua": `return "comments_" .. slug`,
|
||||
"files/*path.lua": `return "file_" .. path`,
|
||||
"middleware.lua": `-- root middleware`,
|
||||
"api/middleware.lua": `-- api middleware`,
|
||||
}
|
||||
|
||||
for path, content := range routes {
|
||||
fullPath := filepath.Join(tempDir, path)
|
||||
dir := filepath.Dir(fullPath)
|
||||
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
return tempDir
|
||||
}
|
||||
|
||||
func TestRouterBasicFunctionality(t *testing.T) {
|
||||
routesDir := setupTestRoutes(t)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
tests := []struct {
|
||||
method string
|
||||
path string
|
||||
expected bool
|
||||
params map[string]string
|
||||
}{
|
||||
{"GET", "/", true, nil},
|
||||
{"GET", "/about", true, nil},
|
||||
{"GET", "/api/users", true, nil},
|
||||
{"GET", "/api/users", true, nil},
|
||||
{"POST", "/api/users", true, nil},
|
||||
{"GET", "/api/users/123", true, map[string]string{"id": "123"}},
|
||||
{"GET", "/api/posts/hello-world/comments", true, map[string]string{"slug": "hello-world"}},
|
||||
{"GET", "/files/docs/readme.txt", true, map[string]string{"path": "docs/readme.txt"}},
|
||||
{"GET", "/nonexistent", false, nil},
|
||||
{"DELETE", "/api/users", false, nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.method+"_"+tt.path, func(t *testing.T) {
|
||||
bytecode, params, found := router.Lookup(tt.method, tt.path)
|
||||
|
||||
if found != tt.expected {
|
||||
t.Errorf("expected found=%v, got %v", tt.expected, found)
|
||||
}
|
||||
|
||||
if tt.expected {
|
||||
if bytecode == nil {
|
||||
t.Error("expected bytecode, got nil")
|
||||
}
|
||||
|
||||
if tt.params != nil {
|
||||
for key, expectedValue := range tt.params {
|
||||
if actualValue := params.Get(key); actualValue != expectedValue {
|
||||
t.Errorf("param %s: expected %s, got %s", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterParamsStruct(t *testing.T) {
|
||||
params := &Params{
|
||||
Keys: []string{"id", "slug"},
|
||||
Values: []string{"123", "hello"},
|
||||
}
|
||||
|
||||
if params.Get("id") != "123" {
|
||||
t.Errorf("expected '123', got '%s'", params.Get("id"))
|
||||
}
|
||||
|
||||
if params.Get("slug") != "hello" {
|
||||
t.Errorf("expected 'hello', got '%s'", params.Get("slug"))
|
||||
}
|
||||
|
||||
if params.Get("missing") != "" {
|
||||
t.Errorf("expected empty string for missing param, got '%s'", params.Get("missing"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterMethodNodes(t *testing.T) {
|
||||
routesDir := setupTestRoutes(t)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
// Test that different methods work independently
|
||||
_, _, foundGet := router.Lookup("GET", "/api/users")
|
||||
_, _, foundPost := router.Lookup("POST", "/api/users")
|
||||
_, _, foundPut := router.Lookup("PUT", "/api/users")
|
||||
|
||||
if !foundGet {
|
||||
t.Error("GET /api/users should be found")
|
||||
}
|
||||
if !foundPost {
|
||||
t.Error("POST /api/users should be found")
|
||||
}
|
||||
if foundPut {
|
||||
t.Error("PUT /api/users should not be found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterWildcardValidation(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create invalid wildcard route (not at end)
|
||||
invalidPath := filepath.Join(tempDir, "bad/*path/more.lua")
|
||||
if err := os.MkdirAll(filepath.Dir(invalidPath), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(invalidPath, []byte(`return "bad"`), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := New(tempDir)
|
||||
if err == nil {
|
||||
t.Error("expected error for wildcard not at end")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLookupStatic(b *testing.B) {
|
||||
routesDir := setupTestRoutes(b)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
method := "GET"
|
||||
path := "/api/users"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = router.Lookup(method, path)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLookupDynamic(b *testing.B) {
|
||||
routesDir := setupTestRoutes(b)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
method := "GET"
|
||||
path := "/api/users/12345"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = router.Lookup(method, path)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLookupWildcard(b *testing.B) {
|
||||
routesDir := setupTestRoutes(b)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
method := "GET"
|
||||
path := "/files/docs/deep/nested/file.txt"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = router.Lookup(method, path)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLookupComplex(b *testing.B) {
|
||||
routesDir := setupTestRoutes(b)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
method := "GET"
|
||||
path := "/api/posts/my-blog-post-title/comments"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = router.Lookup(method, path)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLookupNotFound(b *testing.B) {
|
||||
routesDir := setupTestRoutes(b)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
method := "GET"
|
||||
path := "/this/path/does/not/exist"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = router.Lookup(method, path)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLookupMixed(b *testing.B) {
|
||||
routesDir := setupTestRoutes(b)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
paths := []string{
|
||||
"/",
|
||||
"/about",
|
||||
"/api/users",
|
||||
"/api/users/123",
|
||||
"/api/posts/hello/comments",
|
||||
"/files/document.pdf",
|
||||
"/nonexistent",
|
||||
}
|
||||
method := "GET"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
path := paths[i%len(paths)]
|
||||
_, _, _ = router.Lookup(method, path)
|
||||
}
|
||||
}
|
||||
|
||||
// Comparison benchmarks for string vs byte slice performance
|
||||
func BenchmarkLookupStringConversion(b *testing.B) {
|
||||
routesDir := setupTestRoutes(b)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
methodStr := "GET"
|
||||
pathStr := "/api/users/12345"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Direct string usage
|
||||
_, _, _ = router.Lookup(methodStr, pathStr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLookupPreallocated(b *testing.B) {
|
||||
routesDir := setupTestRoutes(b)
|
||||
router, err := New(routesDir)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer router.Close()
|
||||
|
||||
// Pre-allocated strings (optimal case)
|
||||
method := "GET"
|
||||
path := "/api/users/12345"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = router.Lookup(method, path)
|
||||
}
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Generic interface to support different types of execution contexts for the runner
|
||||
type ExecutionContext interface {
|
||||
Get(key string) any
|
||||
Set(key string, value any)
|
||||
ToMap() map[string]any
|
||||
Release()
|
||||
}
|
||||
|
||||
// This is a generic context that satisfies the runner's ExecutionContext interface
|
||||
type Context struct {
|
||||
Values map[string]any // Any data we want to pass to the state's global ctx table.
|
||||
}
|
||||
|
||||
// Context pool to reduce allocations
|
||||
var contextPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Context{
|
||||
Values: make(map[string]any, 32),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Gets a new context from the pool
|
||||
func NewContext() *Context {
|
||||
ctx := contextPool.Get().(*Context)
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Release returns the context to the pool after clearing its values
|
||||
func (c *Context) Release() {
|
||||
// Clear all values to prevent data leakage
|
||||
for k := range c.Values {
|
||||
delete(c.Values, k)
|
||||
}
|
||||
|
||||
contextPool.Put(c)
|
||||
}
|
||||
|
||||
// Set adds a value to the context
|
||||
func (c *Context) Set(key string, value any) {
|
||||
c.Values[key] = value
|
||||
}
|
||||
|
||||
// Get retrieves a value from the context
|
||||
func (c *Context) Get(key string) any {
|
||||
return c.Values[key]
|
||||
}
|
||||
|
||||
// We can just return the Values map as it's already g2g for Lua
|
||||
func (c *Context) ToMap() map[string]any {
|
||||
return c.Values
|
||||
}
|
130
runner/embed.go
130
runner/embed.go
@ -1,130 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"Moonshark/logger"
|
||||
_ "embed"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
//go:embed lua/sandbox.lua
|
||||
var sandboxLuaCode string
|
||||
|
||||
//go:embed lua/json.lua
|
||||
var jsonLuaCode string
|
||||
|
||||
//go:embed lua/sqlite.lua
|
||||
var sqliteLuaCode string
|
||||
|
||||
//go:embed lua/fs.lua
|
||||
var fsLuaCode string
|
||||
|
||||
//go:embed lua/util.lua
|
||||
var utilLuaCode string
|
||||
|
||||
//go:embed lua/string.lua
|
||||
var stringLuaCode string
|
||||
|
||||
//go:embed lua/table.lua
|
||||
var tableLuaCode string
|
||||
|
||||
//go:embed lua/crypto.lua
|
||||
var cryptoLuaCode string
|
||||
|
||||
//go:embed lua/time.lua
|
||||
var timeLuaCode string
|
||||
|
||||
//go:embed lua/math.lua
|
||||
var mathLuaCode string
|
||||
|
||||
//go:embed lua/env.lua
|
||||
var envLuaCode string
|
||||
|
||||
//go:embed lua/http.lua
|
||||
var httpLuaCode string
|
||||
|
||||
//go:embed lua/cookie.lua
|
||||
var cookieLuaCode string
|
||||
|
||||
//go:embed lua/csrf.lua
|
||||
var csrfLuaCode string
|
||||
|
||||
//go:embed lua/render.lua
|
||||
var renderLuaCode string
|
||||
|
||||
//go:embed lua/session.lua
|
||||
var sessionLuaCode string
|
||||
|
||||
//go:embed lua/timestamp.lua
|
||||
var timestampLuaCode string
|
||||
|
||||
// Module represents a Lua module to load
|
||||
type Module struct {
|
||||
name string
|
||||
code string
|
||||
global bool // true if module defines globals, false if it returns a table
|
||||
}
|
||||
|
||||
var modules = []Module{
|
||||
{"http", httpLuaCode, true},
|
||||
{"string", stringLuaCode, false},
|
||||
{"table", tableLuaCode, false},
|
||||
{"util", utilLuaCode, true},
|
||||
{"cookie", cookieLuaCode, true},
|
||||
{"session", sessionLuaCode, true},
|
||||
{"csrf", csrfLuaCode, true},
|
||||
{"render", renderLuaCode, true},
|
||||
{"json", jsonLuaCode, true},
|
||||
{"fs", fsLuaCode, true},
|
||||
{"crypto", cryptoLuaCode, true},
|
||||
{"time", timeLuaCode, false},
|
||||
{"math", mathLuaCode, false},
|
||||
{"env", envLuaCode, true},
|
||||
{"sqlite", sqliteLuaCode, true},
|
||||
{"timestamp", timestampLuaCode, false},
|
||||
}
|
||||
|
||||
// loadModule loads a single module into the Lua state
|
||||
func loadModule(state *luajit.State, m Module) error {
|
||||
if m.global {
|
||||
// Module defines globals directly, just execute it
|
||||
return state.DoString(m.code)
|
||||
}
|
||||
|
||||
// Module returns a table, capture it and set as global
|
||||
if err := state.LoadString(m.code); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.Call(0, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
state.SetGlobal(m.name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadSandboxIntoState loads all modules and sandbox into a Lua state
|
||||
func loadSandboxIntoState(state *luajit.State, verbose bool) error {
|
||||
// Load all utility modules
|
||||
for _, module := range modules {
|
||||
if err := loadModule(state, module); err != nil {
|
||||
if verbose {
|
||||
logger.Errorf("Failed to load %s module: %v", module.name, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
if verbose {
|
||||
logger.Debugf("Loaded %s.lua", module.name)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize module-specific globals
|
||||
if err := state.DoString(`__active_sqlite_connections = {}`); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load sandbox last - defines __execute and core functions
|
||||
if verbose {
|
||||
logger.Debugf("Loading sandbox.lua")
|
||||
}
|
||||
return state.DoString(sandboxLuaCode)
|
||||
}
|
@ -1,136 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"Moonshark/router"
|
||||
"Moonshark/runner/lualibs"
|
||||
"Moonshark/sessions"
|
||||
"Moonshark/utils"
|
||||
"sync"
|
||||
|
||||
"maps"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// A prebuilt, ready-to-go context for HTTP requests to the runner.
|
||||
type HTTPContext struct {
|
||||
Values map[string]any // Contains all context data for Lua
|
||||
|
||||
// Separate maps for efficient access during context building
|
||||
headers map[string]string
|
||||
cookies map[string]string
|
||||
query map[string]string
|
||||
params map[string]string
|
||||
form map[string]any
|
||||
session map[string]any
|
||||
env map[string]any
|
||||
}
|
||||
|
||||
// HTTP context pool to reduce allocations
|
||||
var httpContextPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &HTTPContext{
|
||||
Values: make(map[string]any, 32),
|
||||
headers: make(map[string]string, 16),
|
||||
cookies: make(map[string]string, 8),
|
||||
query: make(map[string]string, 8),
|
||||
params: make(map[string]string, 4),
|
||||
form: make(map[string]any, 8),
|
||||
session: make(map[string]any, 4),
|
||||
env: make(map[string]any, 16),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Get a clean HTTP context from the pool and build it up with an HTTP request, router params and session data
|
||||
func NewHTTPContext(httpCtx *fasthttp.RequestCtx, params *router.Params, session *sessions.Session) *HTTPContext {
|
||||
ctx := httpContextPool.Get().(*HTTPContext)
|
||||
|
||||
// Extract headers
|
||||
httpCtx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
ctx.headers[string(key)] = string(value)
|
||||
})
|
||||
|
||||
// Extract cookies
|
||||
httpCtx.Request.Header.VisitAllCookie(func(key, value []byte) {
|
||||
ctx.cookies[string(key)] = string(value)
|
||||
})
|
||||
|
||||
// Extract query params
|
||||
httpCtx.QueryArgs().VisitAll(func(key, value []byte) {
|
||||
ctx.query[string(key)] = string(value)
|
||||
})
|
||||
|
||||
// Extract route parameters
|
||||
if params != nil {
|
||||
for i := range min(len(params.Keys), len(params.Values)) {
|
||||
ctx.params[params.Keys[i]] = params.Values[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Extract form data if present
|
||||
if httpCtx.IsPost() || httpCtx.IsPut() || httpCtx.IsPatch() {
|
||||
if form, err := utils.ParseForm(httpCtx); err == nil {
|
||||
maps.Copy(ctx.form, form)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract session data
|
||||
session.AdvanceFlash()
|
||||
ctx.session["id"] = session.ID
|
||||
if session.IsEmpty() {
|
||||
ctx.session["data"] = emptyMap
|
||||
ctx.session["flash"] = emptyMap
|
||||
} else {
|
||||
ctx.session["data"] = session.GetAll()
|
||||
ctx.session["flash"] = session.GetAllFlash()
|
||||
}
|
||||
|
||||
// Add environment vars
|
||||
if envMgr := lualibs.GetGlobalEnvManager(); envMgr != nil {
|
||||
maps.Copy(ctx.env, envMgr.GetAll())
|
||||
}
|
||||
|
||||
// Populate Values with all context data
|
||||
ctx.Values["method"] = string(httpCtx.Method())
|
||||
ctx.Values["path"] = string(httpCtx.Path())
|
||||
ctx.Values["host"] = string(httpCtx.Host())
|
||||
ctx.Values["headers"] = ctx.headers
|
||||
ctx.Values["cookies"] = ctx.cookies
|
||||
ctx.Values["query"] = ctx.query
|
||||
ctx.Values["params"] = ctx.params
|
||||
ctx.Values["form"] = ctx.form
|
||||
ctx.Values["session"] = ctx.session
|
||||
ctx.Values["env"] = ctx.env
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Clear out all the request data from the context and give it back to the pool.
|
||||
func (c *HTTPContext) Release() {
|
||||
clear(c.Values)
|
||||
clear(c.headers)
|
||||
clear(c.cookies)
|
||||
clear(c.query)
|
||||
clear(c.params)
|
||||
clear(c.form)
|
||||
clear(c.session)
|
||||
clear(c.env)
|
||||
|
||||
httpContextPool.Put(c)
|
||||
}
|
||||
|
||||
// Add a value to the extras section
|
||||
func (c *HTTPContext) Set(key string, value any) {
|
||||
c.Values[key] = value
|
||||
}
|
||||
|
||||
// Get a value from the context
|
||||
func (c *HTTPContext) Get(key string) any {
|
||||
return c.Values[key]
|
||||
}
|
||||
|
||||
// Returns the Values map directly - zero overhead
|
||||
func (c *HTTPContext) ToMap() map[string]any {
|
||||
return c.Values
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
-- cookie.lua
|
||||
|
||||
function cookie_set(name, value, options)
|
||||
__response.cookies = __response.cookies or {}
|
||||
local opts = options or {}
|
||||
local cookie = {
|
||||
name = name,
|
||||
value = value or "",
|
||||
path = opts.path or "/",
|
||||
domain = opts.domain,
|
||||
secure = opts.secure ~= false,
|
||||
http_only = opts.http_only ~= false
|
||||
}
|
||||
if opts.expires and opts.expires > 0 then
|
||||
cookie.max_age = opts.expires
|
||||
end
|
||||
table.insert(__response.cookies, cookie)
|
||||
end
|
||||
|
||||
function cookie_get(name)
|
||||
return __ctx.cookies and __ctx.cookies[name]
|
||||
end
|
||||
|
||||
function cookie_delete(name, path, domain)
|
||||
return cookie_set(name, "", {expires = -1, path = path or "/", domain = domain})
|
||||
end
|
@ -1,140 +0,0 @@
|
||||
-- crypto.lua
|
||||
|
||||
-- ======================================================================
|
||||
-- HASHING FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate hash digest using various algorithms
|
||||
-- Algorithms: md5, sha1, sha256, sha512
|
||||
-- Formats: hex (default), binary
|
||||
function hash(data, algorithm, format)
|
||||
if type(data) ~= "string" then
|
||||
error("hash: data must be a string", 2)
|
||||
end
|
||||
|
||||
algorithm = algorithm or "sha256"
|
||||
format = format or "hex"
|
||||
|
||||
return __crypto_hash(data, algorithm, format)
|
||||
end
|
||||
|
||||
function md5(data, format)
|
||||
return hash(data, "md5", format)
|
||||
end
|
||||
|
||||
function sha1(data, format)
|
||||
return hash(data, "sha1", format)
|
||||
end
|
||||
|
||||
function sha256(data, format)
|
||||
return hash(data, "sha256", format)
|
||||
end
|
||||
|
||||
function sha512(data, format)
|
||||
return hash(data, "sha512", format)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- HMAC FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate HMAC using various algorithms
|
||||
-- Algorithms: md5, sha1, sha256, sha512
|
||||
-- Formats: hex (default), binary
|
||||
function hmac(data, key, algorithm, format)
|
||||
if type(data) ~= "string" then
|
||||
error("hmac: data must be a string", 2)
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
error("hmac: key must be a string", 2)
|
||||
end
|
||||
|
||||
algorithm = algorithm or "sha256"
|
||||
format = format or "hex"
|
||||
|
||||
return __crypto_hmac(data, key, algorithm, format)
|
||||
end
|
||||
|
||||
function hmac_md5(data, key, format)
|
||||
return hmac(data, key, "md5", format)
|
||||
end
|
||||
|
||||
function hmac_sha1(data, key, format)
|
||||
return hmac(data, key, "sha1", format)
|
||||
end
|
||||
|
||||
function hmac_sha256(data, key, format)
|
||||
return hmac(data, key, "sha256", format)
|
||||
end
|
||||
|
||||
function hmac_sha512(data, key, format)
|
||||
return hmac(data, key, "sha512", format)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- RANDOM FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate random bytes
|
||||
-- Formats: binary (default), hex
|
||||
function random_bytes(length, secure, format)
|
||||
if type(length) ~= "number" or length <= 0 then
|
||||
error("random_bytes: length must be positive", 2)
|
||||
end
|
||||
|
||||
secure = secure ~= false -- Default to secure
|
||||
format = format or "binary"
|
||||
|
||||
return __crypto_random_bytes(length, secure, format)
|
||||
end
|
||||
|
||||
-- Generate random integer in range [min, max]
|
||||
function random_int(min, max, secure)
|
||||
if type(min) ~= "number" or type(max) ~= "number" then
|
||||
error("random_int: min and max must be numbers", 2)
|
||||
end
|
||||
|
||||
if max <= min then
|
||||
error("random_int: max must be greater than min", 2)
|
||||
end
|
||||
|
||||
secure = secure ~= false -- Default to secure
|
||||
|
||||
return __crypto_random_int(min, max, secure)
|
||||
end
|
||||
|
||||
-- Generate random string of specified length
|
||||
function random_string(length, charset, secure)
|
||||
if type(length) ~= "number" or length <= 0 then
|
||||
error("random_string: length must be positive", 2)
|
||||
end
|
||||
|
||||
secure = secure ~= false -- Default to secure
|
||||
|
||||
-- Default character set: alphanumeric
|
||||
charset = charset or "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
if type(charset) ~= "string" or #charset == 0 then
|
||||
error("random_string: charset must be non-empty", 2)
|
||||
end
|
||||
|
||||
local result = ""
|
||||
local charset_length = #charset
|
||||
|
||||
for i = 1, length do
|
||||
local index = random_int(1, charset_length, secure)
|
||||
result = result .. charset:sub(index, index)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- UUID FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate random UUID (v4)
|
||||
function uuid()
|
||||
return __crypto_uuid()
|
||||
end
|
@ -1,33 +0,0 @@
|
||||
-- csrf.lua
|
||||
|
||||
function csrf_generate()
|
||||
local token = generate_token(32)
|
||||
session_set("_csrf_token", token)
|
||||
return token
|
||||
end
|
||||
|
||||
function csrf_field()
|
||||
local token = session_get("_csrf_token")
|
||||
if not token then
|
||||
token = csrf_generate()
|
||||
end
|
||||
return string.format('<input type="hidden" name="_csrf_token" value="%s" />',
|
||||
html_special_chars(token))
|
||||
end
|
||||
|
||||
function csrf_validate()
|
||||
local token = __ctx.session and __ctx.session.data and __ctx.session.data["_csrf_token"]
|
||||
if not token then
|
||||
__response.status = 403
|
||||
coroutine.yield("__EXIT__")
|
||||
end
|
||||
|
||||
local request_token = (__ctx._request_form and __ctx._request_form._csrf_token) or
|
||||
(__ctx._request_headers and (__ctx._request_headers["x-csrf-token"] or __ctx._request_headers["csrf-token"]))
|
||||
|
||||
if not request_token or request_token ~= token then
|
||||
__response.status = 403
|
||||
coroutine.yield("__EXIT__")
|
||||
end
|
||||
return true
|
||||
end
|
@ -1,89 +0,0 @@
|
||||
-- env.lua
|
||||
|
||||
-- Get an environment variable with a default value
|
||||
function env_get(key, default_value)
|
||||
if type(key) ~= "string" then
|
||||
error("env_get: key must be a string")
|
||||
end
|
||||
|
||||
-- Check context for environment variables
|
||||
if __ctx and __ctx.env and __ctx.env[key] ~= nil then
|
||||
return __ctx.env[key]
|
||||
end
|
||||
|
||||
return default_value
|
||||
end
|
||||
|
||||
-- Set an environment variable
|
||||
function env_set(key, value)
|
||||
if type(key) ~= "string" then
|
||||
error("env_set: key must be a string")
|
||||
end
|
||||
|
||||
-- Update context immediately for future reads
|
||||
if __ctx then
|
||||
__ctx.env = __ctx.env or {}
|
||||
__ctx.env[key] = value
|
||||
end
|
||||
|
||||
-- Persist to Go backend
|
||||
return __env_set(key, value)
|
||||
end
|
||||
|
||||
-- Get all environment variables as a table
|
||||
function env_get_all()
|
||||
-- Return context table directly if available
|
||||
if __ctx and __ctx.env then
|
||||
local copy = {}
|
||||
for k, v in pairs(__ctx.env) do
|
||||
copy[k] = v
|
||||
end
|
||||
return copy
|
||||
end
|
||||
|
||||
-- Fallback to Go call
|
||||
return __env_get_all()
|
||||
end
|
||||
|
||||
-- Check if an environment variable exists
|
||||
function env_exists(key)
|
||||
if type(key) ~= "string" then
|
||||
error("env_exists: key must be a string")
|
||||
end
|
||||
|
||||
-- Check context first
|
||||
if __ctx and __ctx.env then
|
||||
return __ctx.env[key] ~= nil
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
-- Set multiple environment variables from a table
|
||||
function env_set_many(vars)
|
||||
if type(vars) ~= "table" then
|
||||
error("env_set_many: vars must be a table")
|
||||
end
|
||||
|
||||
if __ctx then
|
||||
__ctx.env = __ctx.env or {}
|
||||
end
|
||||
|
||||
local success = true
|
||||
for key, value in pairs(vars) do
|
||||
if type(key) == "string" then
|
||||
-- Update context
|
||||
if __ctx and __ctx.env then
|
||||
__ctx.env[key] = value
|
||||
end
|
||||
-- Persist to Go
|
||||
if not __env_set(key, value) then
|
||||
success = false
|
||||
end
|
||||
else
|
||||
error("env_set_many: all keys must be strings")
|
||||
end
|
||||
end
|
||||
|
||||
return success
|
||||
end
|
@ -1,136 +0,0 @@
|
||||
-- fs.lua
|
||||
|
||||
function fs_read(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_read: path must be a string", 2)
|
||||
end
|
||||
return __fs_read_file(path)
|
||||
end
|
||||
|
||||
function fs_write(path, content)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_write: path must be a string", 2)
|
||||
end
|
||||
if type(content) ~= "string" then
|
||||
error("fs_write: content must be a string", 2)
|
||||
end
|
||||
return __fs_write_file(path, content)
|
||||
end
|
||||
|
||||
function fs_append(path, content)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_append: path must be a string", 2)
|
||||
end
|
||||
if type(content) ~= "string" then
|
||||
error("fs_append: content must be a string", 2)
|
||||
end
|
||||
return __fs_append_file(path, content)
|
||||
end
|
||||
|
||||
function fs_exists(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_exists: path must be a string", 2)
|
||||
end
|
||||
return __fs_exists(path)
|
||||
end
|
||||
|
||||
function fs_remove(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_remove: path must be a string", 2)
|
||||
end
|
||||
return __fs_remove_file(path)
|
||||
end
|
||||
|
||||
function fs_info(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_info: path must be a string", 2)
|
||||
end
|
||||
local info = __fs_get_info(path)
|
||||
|
||||
-- Convert the Unix timestamp to a readable date
|
||||
if info and info.mod_time then
|
||||
info.mod_time_str = os.date("%Y-%m-%d %H:%M:%S", info.mod_time)
|
||||
end
|
||||
|
||||
return info
|
||||
end
|
||||
|
||||
-- Directory Operations
|
||||
function fs_mkdir(path, mode)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_mkdir: path must be a string", 2)
|
||||
end
|
||||
mode = mode or 0755
|
||||
return __fs_make_dir(path, mode)
|
||||
end
|
||||
|
||||
function fs_ls(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_ls: path must be a string", 2)
|
||||
end
|
||||
return __fs_list_dir(path)
|
||||
end
|
||||
|
||||
function fs_rmdir(path, recursive)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_rmdir: path must be a string", 2)
|
||||
end
|
||||
recursive = recursive or false
|
||||
return __fs_remove_dir(path, recursive)
|
||||
end
|
||||
|
||||
-- Path Operations
|
||||
function fs_join_paths(...)
|
||||
return __fs_join_paths(...)
|
||||
end
|
||||
|
||||
function fs_dir_name(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_dir_name: path must be a string", 2)
|
||||
end
|
||||
return __fs_dir_name(path)
|
||||
end
|
||||
|
||||
function fs_base_name(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_base_name: path must be a string", 2)
|
||||
end
|
||||
return __fs_base_name(path)
|
||||
end
|
||||
|
||||
function fs_extension(path)
|
||||
if type(path) ~= "string" then
|
||||
error("fs_extension: path must be a string", 2)
|
||||
end
|
||||
return __fs_extension(path)
|
||||
end
|
||||
|
||||
-- Utility Functions
|
||||
function fs_read_json(path)
|
||||
local content = fs_read(path)
|
||||
if not content then
|
||||
return nil, "Could not read file"
|
||||
end
|
||||
|
||||
local ok, result = pcall(json.decode, content)
|
||||
if not ok then
|
||||
return nil, "Invalid JSON: " .. tostring(result)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
function fs_write_json(path, data, pretty)
|
||||
if type(data) ~= "table" then
|
||||
error("fs_write_json: data must be a table", 2)
|
||||
end
|
||||
|
||||
local content
|
||||
if pretty then
|
||||
content = json.pretty_print(data)
|
||||
else
|
||||
content = json.encode(data)
|
||||
end
|
||||
|
||||
return fs_write(path, content)
|
||||
end
|
@ -1,72 +0,0 @@
|
||||
-- http.lua
|
||||
|
||||
function http_set_status(code)
|
||||
__response.status = code
|
||||
end
|
||||
|
||||
function http_set_header(name, value)
|
||||
__response.headers = __response.headers or {}
|
||||
__response.headers[name] = value
|
||||
end
|
||||
|
||||
function http_set_content_type(ct)
|
||||
__response.headers = __response.headers or {}
|
||||
__response.headers["Content-Type"] = ct
|
||||
end
|
||||
|
||||
function http_set_metadata(key, value)
|
||||
__response.metadata = __response.metadata or {}
|
||||
__response.metadata[key] = value
|
||||
end
|
||||
|
||||
function http_redirect(url, status)
|
||||
__response.status = status or 302
|
||||
__response.headers = __response.headers or {}
|
||||
__response.headers["Location"] = url
|
||||
coroutine.yield("__EXIT__")
|
||||
end
|
||||
|
||||
function send_html(content)
|
||||
http_set_content_type("text/html")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_json(content)
|
||||
http_set_content_type("application/json")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_text(content)
|
||||
http_set_content_type("text/plain")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_xml(content)
|
||||
http_set_content_type("application/xml")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_javascript(content)
|
||||
http_set_content_type("application/javascript")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_css(content)
|
||||
http_set_content_type("text/css")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_svg(content)
|
||||
http_set_content_type("image/svg+xml")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_csv(content)
|
||||
http_set_content_type("text/csv")
|
||||
return content
|
||||
end
|
||||
|
||||
function send_binary(content, mime_type)
|
||||
http_set_content_type(mime_type or "application/octet-stream")
|
||||
return content
|
||||
end
|
@ -1,421 +0,0 @@
|
||||
-- json.lua
|
||||
|
||||
local escape_chars = {
|
||||
['"'] = '\\"', ['\\'] = '\\\\',
|
||||
['\n'] = '\\n', ['\r'] = '\\r', ['\t'] = '\\t'
|
||||
}
|
||||
|
||||
function json_go_encode(value)
|
||||
return __json_marshal(value)
|
||||
end
|
||||
|
||||
function json_go_decode(str)
|
||||
if type(str) ~= "string" then
|
||||
error("json_decode: expected string, got " .. type(str), 2)
|
||||
end
|
||||
return __json_unmarshal(str)
|
||||
end
|
||||
|
||||
function json_encode(data)
|
||||
local t = type(data)
|
||||
|
||||
if t == "nil" then return "null" end
|
||||
if t == "boolean" then return data and "true" or "false" end
|
||||
if t == "number" then return tostring(data) end
|
||||
|
||||
if t == "string" then
|
||||
return '"' .. data:gsub('[\\"\n\r\t]', escape_chars) .. '"'
|
||||
end
|
||||
|
||||
if t == "table" then
|
||||
local isArray = true
|
||||
local count = 0
|
||||
|
||||
-- Check if it's an array in one pass
|
||||
for k, _ in pairs(data) do
|
||||
count = count + 1
|
||||
if type(k) ~= "number" or k ~= count or k < 1 then
|
||||
isArray = false
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if isArray then
|
||||
local result = {}
|
||||
for i = 1, count do
|
||||
result[i] = json_encode(data[i])
|
||||
end
|
||||
return "[" .. table.concat(result, ",") .. "]"
|
||||
else
|
||||
local result = {}
|
||||
local index = 1
|
||||
for k, v in pairs(data) do
|
||||
if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then
|
||||
result[index] = json_encode(k) .. ":" .. json_encode(v)
|
||||
index = index + 1
|
||||
end
|
||||
end
|
||||
return "{" .. table.concat(result, ",") .. "}"
|
||||
end
|
||||
end
|
||||
|
||||
return "null" -- Unsupported type
|
||||
end
|
||||
|
||||
function json_decode(data)
|
||||
local pos = 1
|
||||
local len = #data
|
||||
|
||||
-- Pre-compute byte values
|
||||
local b_space = string.byte(' ')
|
||||
local b_tab = string.byte('\t')
|
||||
local b_cr = string.byte('\r')
|
||||
local b_lf = string.byte('\n')
|
||||
local b_quote = string.byte('"')
|
||||
local b_backslash = string.byte('\\')
|
||||
local b_slash = string.byte('/')
|
||||
local b_lcurly = string.byte('{')
|
||||
local b_rcurly = string.byte('}')
|
||||
local b_lbracket = string.byte('[')
|
||||
local b_rbracket = string.byte(']')
|
||||
local b_colon = string.byte(':')
|
||||
local b_comma = string.byte(',')
|
||||
local b_0 = string.byte('0')
|
||||
local b_9 = string.byte('9')
|
||||
local b_minus = string.byte('-')
|
||||
local b_plus = string.byte('+')
|
||||
local b_dot = string.byte('.')
|
||||
local b_e = string.byte('e')
|
||||
local b_E = string.byte('E')
|
||||
|
||||
-- Skip whitespace more efficiently
|
||||
local function skip()
|
||||
local b
|
||||
while pos <= len do
|
||||
b = data:byte(pos)
|
||||
if b > b_space or (b ~= b_space and b ~= b_tab and b ~= b_cr and b ~= b_lf) then
|
||||
break
|
||||
end
|
||||
pos = pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
-- Forward declarations
|
||||
local parse_value, parse_string, parse_number, parse_object, parse_array
|
||||
|
||||
-- Parse a string more efficiently
|
||||
parse_string = function()
|
||||
pos = pos + 1 -- Skip opening quote
|
||||
|
||||
if pos > len then
|
||||
error("Unterminated string")
|
||||
end
|
||||
|
||||
-- Use a table to build the string
|
||||
local result = {}
|
||||
local result_pos = 1
|
||||
local start = pos
|
||||
local c, b
|
||||
|
||||
while pos <= len do
|
||||
b = data:byte(pos)
|
||||
|
||||
if b == b_backslash then
|
||||
-- Add the chunk before the escape character
|
||||
if pos > start then
|
||||
result[result_pos] = data:sub(start, pos - 1)
|
||||
result_pos = result_pos + 1
|
||||
end
|
||||
|
||||
pos = pos + 1
|
||||
if pos > len then
|
||||
error("Unterminated string escape")
|
||||
end
|
||||
|
||||
c = data:byte(pos)
|
||||
if c == b_quote then
|
||||
result[result_pos] = '"'
|
||||
elseif c == b_backslash then
|
||||
result[result_pos] = '\\'
|
||||
elseif c == b_slash then
|
||||
result[result_pos] = '/'
|
||||
elseif c == string.byte('b') then
|
||||
result[result_pos] = '\b'
|
||||
elseif c == string.byte('f') then
|
||||
result[result_pos] = '\f'
|
||||
elseif c == string.byte('n') then
|
||||
result[result_pos] = '\n'
|
||||
elseif c == string.byte('r') then
|
||||
result[result_pos] = '\r'
|
||||
elseif c == string.byte('t') then
|
||||
result[result_pos] = '\t'
|
||||
else
|
||||
result[result_pos] = data:sub(pos, pos)
|
||||
end
|
||||
|
||||
result_pos = result_pos + 1
|
||||
pos = pos + 1
|
||||
start = pos
|
||||
elseif b == b_quote then
|
||||
-- Add the final chunk
|
||||
if pos > start then
|
||||
result[result_pos] = data:sub(start, pos - 1)
|
||||
result_pos = result_pos + 1
|
||||
end
|
||||
|
||||
pos = pos + 1
|
||||
return table.concat(result)
|
||||
else
|
||||
pos = pos + 1
|
||||
end
|
||||
end
|
||||
|
||||
error("Unterminated string")
|
||||
end
|
||||
|
||||
-- Parse a number more efficiently
|
||||
parse_number = function()
|
||||
local start = pos
|
||||
local b = data:byte(pos)
|
||||
|
||||
-- Skip any sign
|
||||
if b == b_minus then
|
||||
pos = pos + 1
|
||||
if pos > len then
|
||||
error("Malformed number")
|
||||
end
|
||||
b = data:byte(pos)
|
||||
end
|
||||
|
||||
-- Integer part
|
||||
if b < b_0 or b > b_9 then
|
||||
error("Malformed number")
|
||||
end
|
||||
|
||||
repeat
|
||||
pos = pos + 1
|
||||
if pos > len then break end
|
||||
b = data:byte(pos)
|
||||
until b < b_0 or b > b_9
|
||||
|
||||
-- Fractional part
|
||||
if pos <= len and b == b_dot then
|
||||
pos = pos + 1
|
||||
if pos > len or data:byte(pos) < b_0 or data:byte(pos) > b_9 then
|
||||
error("Malformed number")
|
||||
end
|
||||
|
||||
repeat
|
||||
pos = pos + 1
|
||||
if pos > len then break end
|
||||
b = data:byte(pos)
|
||||
until b < b_0 or b > b_9
|
||||
end
|
||||
|
||||
-- Exponent
|
||||
if pos <= len and (b == b_e or b == b_E) then
|
||||
pos = pos + 1
|
||||
if pos > len then
|
||||
error("Malformed number")
|
||||
end
|
||||
|
||||
b = data:byte(pos)
|
||||
if b == b_plus or b == b_minus then
|
||||
pos = pos + 1
|
||||
if pos > len then
|
||||
error("Malformed number")
|
||||
end
|
||||
b = data:byte(pos)
|
||||
end
|
||||
|
||||
if b < b_0 or b > b_9 then
|
||||
error("Malformed number")
|
||||
end
|
||||
|
||||
repeat
|
||||
pos = pos + 1
|
||||
if pos > len then break end
|
||||
b = data:byte(pos)
|
||||
until b < b_0 or b > b_9
|
||||
end
|
||||
|
||||
return tonumber(data:sub(start, pos - 1))
|
||||
end
|
||||
|
||||
-- Parse an object more efficiently
|
||||
parse_object = function()
|
||||
pos = pos + 1 -- Skip opening brace
|
||||
local obj = {}
|
||||
|
||||
skip()
|
||||
if pos <= len and data:byte(pos) == b_rcurly then
|
||||
pos = pos + 1
|
||||
return obj
|
||||
end
|
||||
|
||||
while pos <= len do
|
||||
skip()
|
||||
|
||||
if data:byte(pos) ~= b_quote then
|
||||
error("Expected string key")
|
||||
end
|
||||
|
||||
local key = parse_string()
|
||||
skip()
|
||||
|
||||
if data:byte(pos) ~= b_colon then
|
||||
error("Expected colon")
|
||||
end
|
||||
pos = pos + 1
|
||||
|
||||
obj[key] = parse_value()
|
||||
skip()
|
||||
|
||||
local b = data:byte(pos)
|
||||
if b == b_rcurly then
|
||||
pos = pos + 1
|
||||
return obj
|
||||
end
|
||||
|
||||
if b ~= b_comma then
|
||||
error("Expected comma or closing brace")
|
||||
end
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
error("Unterminated object")
|
||||
end
|
||||
|
||||
-- Parse an array more efficiently
|
||||
parse_array = function()
|
||||
pos = pos + 1 -- Skip opening bracket
|
||||
local arr = {}
|
||||
local index = 1
|
||||
|
||||
skip()
|
||||
if pos <= len and data:byte(pos) == b_rbracket then
|
||||
pos = pos + 1
|
||||
return arr
|
||||
end
|
||||
|
||||
while pos <= len do
|
||||
arr[index] = parse_value()
|
||||
index = index + 1
|
||||
|
||||
skip()
|
||||
|
||||
local b = data:byte(pos)
|
||||
if b == b_rbracket then
|
||||
pos = pos + 1
|
||||
return arr
|
||||
end
|
||||
|
||||
if b ~= b_comma then
|
||||
error("Expected comma or closing bracket")
|
||||
end
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
error("Unterminated array")
|
||||
end
|
||||
|
||||
-- Parse a value more efficiently
|
||||
parse_value = function()
|
||||
skip()
|
||||
|
||||
if pos > len then
|
||||
error("Unexpected end of input")
|
||||
end
|
||||
|
||||
local b = data:byte(pos)
|
||||
|
||||
if b == b_quote then
|
||||
return parse_string()
|
||||
elseif b == b_lcurly then
|
||||
return parse_object()
|
||||
elseif b == b_lbracket then
|
||||
return parse_array()
|
||||
elseif b == string.byte('n') and pos + 3 <= len and data:sub(pos, pos + 3) == "null" then
|
||||
pos = pos + 4
|
||||
return nil
|
||||
elseif b == string.byte('t') and pos + 3 <= len and data:sub(pos, pos + 3) == "true" then
|
||||
pos = pos + 4
|
||||
return true
|
||||
elseif b == string.byte('f') and pos + 4 <= len and data:sub(pos, pos + 4) == "false" then
|
||||
pos = pos + 5
|
||||
return false
|
||||
elseif b == b_minus or (b >= b_0 and b <= b_9) then
|
||||
return parse_number()
|
||||
else
|
||||
error("Unexpected character: " .. string.char(b))
|
||||
end
|
||||
end
|
||||
|
||||
skip()
|
||||
local result = parse_value()
|
||||
skip()
|
||||
|
||||
if pos <= len then
|
||||
error("Unexpected trailing characters")
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
function json_is_valid(str)
|
||||
if type(str) ~= "string" then return false end
|
||||
local status, _ = pcall(json_decode, str)
|
||||
return status
|
||||
end
|
||||
|
||||
function json_pretty_print(value)
|
||||
if type(value) == "string" then
|
||||
value = json_decode(value)
|
||||
end
|
||||
|
||||
local function stringify(val, indent, visited)
|
||||
visited = visited or {}
|
||||
indent = indent or 0
|
||||
local spaces = string.rep(" ", indent)
|
||||
|
||||
if type(val) == "table" then
|
||||
if visited[val] then return "{...}" end
|
||||
visited[val] = true
|
||||
|
||||
local isArray = true
|
||||
local i = 1
|
||||
for k in pairs(val) do
|
||||
if type(k) ~= "number" or k ~= i then
|
||||
isArray = false
|
||||
break
|
||||
end
|
||||
i = i + 1
|
||||
end
|
||||
|
||||
local result = isArray and "[\n" or "{\n"
|
||||
local first = true
|
||||
|
||||
if isArray then
|
||||
for i, v in ipairs(val) do
|
||||
if not first then result = result .. ",\n" end
|
||||
first = false
|
||||
result = result .. spaces .. " " .. stringify(v, indent + 1, visited)
|
||||
end
|
||||
else
|
||||
for k, v in pairs(val) do
|
||||
if not first then result = result .. ",\n" end
|
||||
first = false
|
||||
result = result .. spaces .. " \"" .. tostring(k) .. "\": " .. stringify(v, indent + 1, visited)
|
||||
end
|
||||
end
|
||||
|
||||
return result .. "\n" .. spaces .. (isArray and "]" or "}")
|
||||
elseif type(val) == "string" then
|
||||
return "\"" .. val:gsub('\\', '\\\\'):gsub('"', '\\"'):gsub('\n', '\\n') .. "\""
|
||||
else
|
||||
return tostring(val)
|
||||
end
|
||||
end
|
||||
|
||||
return stringify(value)
|
||||
end
|
@ -1,800 +0,0 @@
|
||||
-- math.lua
|
||||
|
||||
local math_ext = {}
|
||||
|
||||
-- Import standard math functions
|
||||
for name, func in pairs(_G.math) do
|
||||
math_ext[name] = func
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- CONSTANTS (higher precision)
|
||||
-- ======================================================================
|
||||
|
||||
math_ext.pi = 3.14159265358979323846
|
||||
math_ext.tau = 6.28318530717958647693 -- 2*pi
|
||||
math_ext.e = 2.71828182845904523536
|
||||
math_ext.phi = 1.61803398874989484820 -- Golden ratio
|
||||
math_ext.sqrt2 = 1.41421356237309504880
|
||||
math_ext.sqrt3 = 1.73205080756887729353
|
||||
math_ext.ln2 = 0.69314718055994530942
|
||||
math_ext.ln10 = 2.30258509299404568402
|
||||
math_ext.infinity = 1/0
|
||||
math_ext.nan = 0/0
|
||||
|
||||
-- ======================================================================
|
||||
-- EXTENDED FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Cube root (handles negative numbers correctly)
|
||||
function math_ext.cbrt(x)
|
||||
return x < 0 and -(-x)^(1/3) or x^(1/3)
|
||||
end
|
||||
|
||||
-- Hypotenuse of right-angled triangle
|
||||
function math_ext.hypot(x, y)
|
||||
return math.sqrt(x * x + y * y)
|
||||
end
|
||||
|
||||
-- Check if value is NaN
|
||||
function math_ext.isnan(x)
|
||||
return x ~= x
|
||||
end
|
||||
|
||||
-- Check if value is finite
|
||||
function math_ext.isfinite(x)
|
||||
return x > -math_ext.infinity and x < math_ext.infinity
|
||||
end
|
||||
|
||||
-- Sign function (-1, 0, 1)
|
||||
function math_ext.sign(x)
|
||||
return x > 0 and 1 or (x < 0 and -1 or 0)
|
||||
end
|
||||
|
||||
-- Clamp value between min and max
|
||||
function math_ext.clamp(x, min, max)
|
||||
return x < min and min or (x > max and max or x)
|
||||
end
|
||||
|
||||
-- Linear interpolation
|
||||
function math_ext.lerp(a, b, t)
|
||||
return a + (b - a) * t
|
||||
end
|
||||
|
||||
-- Smooth step interpolation
|
||||
function math_ext.smoothstep(a, b, t)
|
||||
t = math_ext.clamp((t - a) / (b - a), 0, 1)
|
||||
return t * t * (3 - 2 * t)
|
||||
end
|
||||
|
||||
-- Map value from one range to another
|
||||
function math_ext.map(x, in_min, in_max, out_min, out_max)
|
||||
return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min
|
||||
end
|
||||
|
||||
-- Round to nearest integer
|
||||
function math_ext.round(x)
|
||||
return x >= 0 and math.floor(x + 0.5) or math.ceil(x - 0.5)
|
||||
end
|
||||
|
||||
-- Round to specified decimal places
|
||||
function math_ext.roundto(x, decimals)
|
||||
local mult = 10 ^ (decimals or 0)
|
||||
return math.floor(x * mult + 0.5) / mult
|
||||
end
|
||||
|
||||
-- Normalize angle to [-π, π]
|
||||
function math_ext.normalize_angle(angle)
|
||||
return angle - 2 * math_ext.pi * math.floor((angle + math_ext.pi) / (2 * math_ext.pi))
|
||||
end
|
||||
|
||||
-- Distance between points
|
||||
function math_ext.distance(x1, y1, x2, y2)
|
||||
local dx, dy = x2 - x1, y2 - y1
|
||||
return math.sqrt(dx * dx + dy * dy)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- RANDOM NUMBER FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Random float in range [min, max)
|
||||
function math_ext.randomf(min, max)
|
||||
if not min and not max then
|
||||
return math.random()
|
||||
elseif not max then
|
||||
max = min
|
||||
min = 0
|
||||
end
|
||||
return min + math.random() * (max - min)
|
||||
end
|
||||
|
||||
-- Random integer in range [min, max]
|
||||
function math_ext.randint(min, max)
|
||||
if not max then
|
||||
max = min
|
||||
min = 1
|
||||
end
|
||||
return math.floor(math.random() * (max - min + 1) + min)
|
||||
end
|
||||
|
||||
-- Random boolean with probability p (default 0.5)
|
||||
function math_ext.randboolean(p)
|
||||
p = p or 0.5
|
||||
return math.random() < p
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- STATISTICS FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Sum of values
|
||||
function math_ext.sum(t)
|
||||
if type(t) ~= "table" then return 0 end
|
||||
|
||||
local sum = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
sum = sum + t[i]
|
||||
end
|
||||
end
|
||||
return sum
|
||||
end
|
||||
|
||||
-- Mean (average) of values
|
||||
function math_ext.mean(t)
|
||||
if type(t) ~= "table" or #t == 0 then return 0 end
|
||||
|
||||
local sum = 0
|
||||
local count = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
sum = sum + t[i]
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
return count > 0 and sum / count or 0
|
||||
end
|
||||
|
||||
-- Median of values
|
||||
function math_ext.median(t)
|
||||
if type(t) ~= "table" or #t == 0 then return 0 end
|
||||
|
||||
local nums = {}
|
||||
local count = 0
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
count = count + 1
|
||||
nums[count] = t[i]
|
||||
end
|
||||
end
|
||||
|
||||
if count == 0 then return 0 end
|
||||
|
||||
table.sort(nums)
|
||||
|
||||
if count % 2 == 0 then
|
||||
return (nums[count/2] + nums[count/2 + 1]) / 2
|
||||
else
|
||||
return nums[math.ceil(count/2)]
|
||||
end
|
||||
end
|
||||
|
||||
-- Variance of values
|
||||
function math_ext.variance(t)
|
||||
if type(t) ~= "table" then return 0 end
|
||||
|
||||
local count = 0
|
||||
local m = math_ext.mean(t)
|
||||
local sum = 0
|
||||
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
local dev = t[i] - m
|
||||
sum = sum + dev * dev
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
|
||||
return count > 1 and sum / count or 0
|
||||
end
|
||||
|
||||
-- Standard deviation
|
||||
function math_ext.stdev(t)
|
||||
return math.sqrt(math_ext.variance(t))
|
||||
end
|
||||
|
||||
-- Population variance
|
||||
function math_ext.pvariance(t)
|
||||
if type(t) ~= "table" then return 0 end
|
||||
|
||||
local count = 0
|
||||
local m = math_ext.mean(t)
|
||||
local sum = 0
|
||||
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
local dev = t[i] - m
|
||||
sum = sum + dev * dev
|
||||
count = count + 1
|
||||
end
|
||||
end
|
||||
|
||||
return count > 0 and sum / count or 0
|
||||
end
|
||||
|
||||
-- Population standard deviation
|
||||
function math_ext.pstdev(t)
|
||||
return math.sqrt(math_ext.pvariance(t))
|
||||
end
|
||||
|
||||
-- Mode (most common value)
|
||||
function math_ext.mode(t)
|
||||
if type(t) ~= "table" or #t == 0 then return nil end
|
||||
|
||||
local counts = {}
|
||||
local most_frequent = nil
|
||||
local max_count = 0
|
||||
|
||||
for i=1, #t do
|
||||
local v = t[i]
|
||||
counts[v] = (counts[v] or 0) + 1
|
||||
if counts[v] > max_count then
|
||||
max_count = counts[v]
|
||||
most_frequent = v
|
||||
end
|
||||
end
|
||||
|
||||
return most_frequent
|
||||
end
|
||||
|
||||
-- Min and max simultaneously (faster than calling both separately)
|
||||
function math_ext.minmax(t)
|
||||
if type(t) ~= "table" or #t == 0 then return nil, nil end
|
||||
|
||||
local min, max
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
min = t[i]
|
||||
max = t[i]
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if min == nil then return nil, nil end
|
||||
|
||||
for i=1, #t do
|
||||
if type(t[i]) == "number" then
|
||||
if t[i] < min then min = t[i] end
|
||||
if t[i] > max then max = t[i] end
|
||||
end
|
||||
end
|
||||
|
||||
return min, max
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- VECTOR OPERATIONS (2D/3D vectors)
|
||||
-- ======================================================================
|
||||
|
||||
-- 2D Vector operations
|
||||
math_ext.vec2 = {
|
||||
new = function(x, y)
|
||||
return {x = x or 0, y = y or 0}
|
||||
end,
|
||||
|
||||
copy = function(v)
|
||||
return {x = v.x, y = v.y}
|
||||
end,
|
||||
|
||||
add = function(a, b)
|
||||
return {x = a.x + b.x, y = a.y + b.y}
|
||||
end,
|
||||
|
||||
sub = function(a, b)
|
||||
return {x = a.x - b.x, y = a.y - b.y}
|
||||
end,
|
||||
|
||||
mul = function(a, b)
|
||||
if type(b) == "number" then
|
||||
return {x = a.x * b, y = a.y * b}
|
||||
end
|
||||
return {x = a.x * b.x, y = a.y * b.y}
|
||||
end,
|
||||
|
||||
div = function(a, b)
|
||||
if type(b) == "number" then
|
||||
local inv = 1 / b
|
||||
return {x = a.x * inv, y = a.y * inv}
|
||||
end
|
||||
return {x = a.x / b.x, y = a.y / b.y}
|
||||
end,
|
||||
|
||||
dot = function(a, b)
|
||||
return a.x * b.x + a.y * b.y
|
||||
end,
|
||||
|
||||
length = function(v)
|
||||
return math.sqrt(v.x * v.x + v.y * v.y)
|
||||
end,
|
||||
|
||||
length_squared = function(v)
|
||||
return v.x * v.x + v.y * v.y
|
||||
end,
|
||||
|
||||
distance = function(a, b)
|
||||
local dx, dy = b.x - a.x, b.y - a.y
|
||||
return math.sqrt(dx * dx + dy * dy)
|
||||
end,
|
||||
|
||||
distance_squared = function(a, b)
|
||||
local dx, dy = b.x - a.x, b.y - a.y
|
||||
return dx * dx + dy * dy
|
||||
end,
|
||||
|
||||
normalize = function(v)
|
||||
local len = math.sqrt(v.x * v.x + v.y * v.y)
|
||||
if len > 1e-10 then
|
||||
local inv_len = 1 / len
|
||||
return {x = v.x * inv_len, y = v.y * inv_len}
|
||||
end
|
||||
return {x = 0, y = 0}
|
||||
end,
|
||||
|
||||
rotate = function(v, angle)
|
||||
local c, s = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
x = v.x * c - v.y * s,
|
||||
y = v.x * s + v.y * c
|
||||
}
|
||||
end,
|
||||
|
||||
angle = function(v)
|
||||
return math.atan2(v.y, v.x)
|
||||
end,
|
||||
|
||||
lerp = function(a, b, t)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
return {
|
||||
x = a.x + (b.x - a.x) * t,
|
||||
y = a.y + (b.y - a.y) * t
|
||||
}
|
||||
end,
|
||||
|
||||
reflect = function(v, normal)
|
||||
local dot = v.x * normal.x + v.y * normal.y
|
||||
return {
|
||||
x = v.x - 2 * dot * normal.x,
|
||||
y = v.y - 2 * dot * normal.y
|
||||
}
|
||||
end
|
||||
}
|
||||
|
||||
-- 3D Vector operations
|
||||
math_ext.vec3 = {
|
||||
new = function(x, y, z)
|
||||
return {x = x or 0, y = y or 0, z = z or 0}
|
||||
end,
|
||||
|
||||
copy = function(v)
|
||||
return {x = v.x, y = v.y, z = v.z}
|
||||
end,
|
||||
|
||||
add = function(a, b)
|
||||
return {x = a.x + b.x, y = a.y + b.y, z = a.z + b.z}
|
||||
end,
|
||||
|
||||
sub = function(a, b)
|
||||
return {x = a.x - b.x, y = a.y - b.y, z = a.z - b.z}
|
||||
end,
|
||||
|
||||
mul = function(a, b)
|
||||
if type(b) == "number" then
|
||||
return {x = a.x * b, y = a.y * b, z = a.z * b}
|
||||
end
|
||||
return {x = a.x * b.x, y = a.y * b.y, z = a.z * b.z}
|
||||
end,
|
||||
|
||||
div = function(a, b)
|
||||
if type(b) == "number" then
|
||||
local inv = 1 / b
|
||||
return {x = a.x * inv, y = a.y * inv, z = a.z * inv}
|
||||
end
|
||||
return {x = a.x / b.x, y = a.y / b.y, z = a.z / b.z}
|
||||
end,
|
||||
|
||||
dot = function(a, b)
|
||||
return a.x * b.x + a.y * b.y + a.z * b.z
|
||||
end,
|
||||
|
||||
cross = function(a, b)
|
||||
return {
|
||||
x = a.y * b.z - a.z * b.y,
|
||||
y = a.z * b.x - a.x * b.z,
|
||||
z = a.x * b.y - a.y * b.x
|
||||
}
|
||||
end,
|
||||
|
||||
length = function(v)
|
||||
return math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
|
||||
end,
|
||||
|
||||
length_squared = function(v)
|
||||
return v.x * v.x + v.y * v.y + v.z * v.z
|
||||
end,
|
||||
|
||||
distance = function(a, b)
|
||||
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
|
||||
return math.sqrt(dx * dx + dy * dy + dz * dz)
|
||||
end,
|
||||
|
||||
distance_squared = function(a, b)
|
||||
local dx, dy, dz = b.x - a.x, b.y - a.y, b.z - a.z
|
||||
return dx * dx + dy * dy + dz * dz
|
||||
end,
|
||||
|
||||
normalize = function(v)
|
||||
local len = math.sqrt(v.x * v.x + v.y * v.y + v.z * v.z)
|
||||
if len > 1e-10 then
|
||||
local inv_len = 1 / len
|
||||
return {x = v.x * inv_len, y = v.y * inv_len, z = v.z * inv_len}
|
||||
end
|
||||
return {x = 0, y = 0, z = 0}
|
||||
end,
|
||||
|
||||
lerp = function(a, b, t)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
return {
|
||||
x = a.x + (b.x - a.x) * t,
|
||||
y = a.y + (b.y - a.y) * t,
|
||||
z = a.z + (b.z - a.z) * t
|
||||
}
|
||||
end,
|
||||
|
||||
reflect = function(v, normal)
|
||||
local dot = v.x * normal.x + v.y * normal.y + v.z * normal.z
|
||||
return {
|
||||
x = v.x - 2 * dot * normal.x,
|
||||
y = v.y - 2 * dot * normal.y,
|
||||
z = v.z - 2 * dot * normal.z
|
||||
}
|
||||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- MATRIX OPERATIONS (2x2 and 3x3 matrices)
|
||||
-- ======================================================================
|
||||
|
||||
math_ext.mat2 = {
|
||||
-- Create a new 2x2 matrix
|
||||
new = function(a, b, c, d)
|
||||
return {
|
||||
{a or 1, b or 0},
|
||||
{c or 0, d or 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Create identity matrix
|
||||
identity = function()
|
||||
return {{1, 0}, {0, 1}}
|
||||
end,
|
||||
|
||||
-- Matrix multiplication
|
||||
mul = function(a, b)
|
||||
return {
|
||||
{
|
||||
a[1][1] * b[1][1] + a[1][2] * b[2][1],
|
||||
a[1][1] * b[1][2] + a[1][2] * b[2][2]
|
||||
},
|
||||
{
|
||||
a[2][1] * b[1][1] + a[2][2] * b[2][1],
|
||||
a[2][1] * b[1][2] + a[2][2] * b[2][2]
|
||||
}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Determinant
|
||||
det = function(m)
|
||||
return m[1][1] * m[2][2] - m[1][2] * m[2][1]
|
||||
end,
|
||||
|
||||
-- Inverse matrix
|
||||
inverse = function(m)
|
||||
local det = m[1][1] * m[2][2] - m[1][2] * m[2][1]
|
||||
if math.abs(det) < 1e-10 then
|
||||
return nil -- Matrix is not invertible
|
||||
end
|
||||
|
||||
local inv_det = 1 / det
|
||||
return {
|
||||
{m[2][2] * inv_det, -m[1][2] * inv_det},
|
||||
{-m[2][1] * inv_det, m[1][1] * inv_det}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Rotation matrix
|
||||
rotation = function(angle)
|
||||
local cos, sin = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
{cos, -sin},
|
||||
{sin, cos}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Apply matrix to vector
|
||||
transform = function(m, v)
|
||||
return {
|
||||
x = m[1][1] * v.x + m[1][2] * v.y,
|
||||
y = m[2][1] * v.x + m[2][2] * v.y
|
||||
}
|
||||
end,
|
||||
|
||||
-- Scale matrix
|
||||
scale = function(sx, sy)
|
||||
sy = sy or sx
|
||||
return {
|
||||
{sx, 0},
|
||||
{0, sy}
|
||||
}
|
||||
end
|
||||
}
|
||||
|
||||
math_ext.mat3 = {
|
||||
-- Create identity matrix 3x3
|
||||
identity = function()
|
||||
return {
|
||||
{1, 0, 0},
|
||||
{0, 1, 0},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Create a 2D transformation matrix (translation, rotation, scale)
|
||||
transform = function(x, y, angle, sx, sy)
|
||||
sx = sx or 1
|
||||
sy = sy or sx
|
||||
local cos, sin = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
{cos * sx, -sin * sy, x},
|
||||
{sin * sx, cos * sy, y},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Matrix multiplication
|
||||
mul = function(a, b)
|
||||
local result = {
|
||||
{0, 0, 0},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0}
|
||||
}
|
||||
|
||||
for i = 1, 3 do
|
||||
for j = 1, 3 do
|
||||
for k = 1, 3 do
|
||||
result[i][j] = result[i][j] + a[i][k] * b[k][j]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return result
|
||||
end,
|
||||
|
||||
-- Apply matrix to point (homogeneous coordinates)
|
||||
transform_point = function(m, v)
|
||||
local x = m[1][1] * v.x + m[1][2] * v.y + m[1][3]
|
||||
local y = m[2][1] * v.x + m[2][2] * v.y + m[2][3]
|
||||
local w = m[3][1] * v.x + m[3][2] * v.y + m[3][3]
|
||||
|
||||
if math.abs(w) < 1e-10 then
|
||||
return {x = 0, y = 0}
|
||||
end
|
||||
|
||||
return {x = x / w, y = y / w}
|
||||
end,
|
||||
|
||||
-- Translation matrix
|
||||
translation = function(x, y)
|
||||
return {
|
||||
{1, 0, x},
|
||||
{0, 1, y},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Rotation matrix
|
||||
rotation = function(angle)
|
||||
local cos, sin = math.cos(angle), math.sin(angle)
|
||||
return {
|
||||
{cos, -sin, 0},
|
||||
{sin, cos, 0},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Scale matrix
|
||||
scale = function(sx, sy)
|
||||
sy = sy or sx
|
||||
return {
|
||||
{sx, 0, 0},
|
||||
{0, sy, 0},
|
||||
{0, 0, 1}
|
||||
}
|
||||
end,
|
||||
|
||||
-- Determinant
|
||||
det = function(m)
|
||||
return m[1][1] * (m[2][2] * m[3][3] - m[2][3] * m[3][2]) -
|
||||
m[1][2] * (m[2][1] * m[3][3] - m[2][3] * m[3][1]) +
|
||||
m[1][3] * (m[2][1] * m[3][2] - m[2][2] * m[3][1])
|
||||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- GEOMETRY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
math_ext.geometry = {
|
||||
-- Distance from point to line
|
||||
point_line_distance = function(px, py, x1, y1, x2, y2)
|
||||
local dx, dy = x2 - x1, y2 - y1
|
||||
local len_sq = dx * dx + dy * dy
|
||||
|
||||
if len_sq < 1e-10 then
|
||||
return math_ext.distance(px, py, x1, y1)
|
||||
end
|
||||
|
||||
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
|
||||
local nearestX = x1 + t * dx
|
||||
local nearestY = y1 + t * dy
|
||||
|
||||
return math_ext.distance(px, py, nearestX, nearestY)
|
||||
end,
|
||||
|
||||
-- Check if point is inside polygon
|
||||
point_in_polygon = function(px, py, vertices)
|
||||
local inside = false
|
||||
local n = #vertices / 2
|
||||
|
||||
for i = 1, n do
|
||||
local x1, y1 = vertices[i*2-1], vertices[i*2]
|
||||
local x2, y2
|
||||
|
||||
if i == n then
|
||||
x2, y2 = vertices[1], vertices[2]
|
||||
else
|
||||
x2, y2 = vertices[i*2+1], vertices[i*2+2]
|
||||
end
|
||||
|
||||
if ((y1 > py) ~= (y2 > py)) and
|
||||
(px < (x2 - x1) * (py - y1) / (y2 - y1) + x1) then
|
||||
inside = not inside
|
||||
end
|
||||
end
|
||||
|
||||
return inside
|
||||
end,
|
||||
|
||||
-- Area of a triangle
|
||||
triangle_area = function(x1, y1, x2, y2, x3, y3)
|
||||
return math.abs((x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)) / 2)
|
||||
end,
|
||||
|
||||
-- Check if point is inside triangle
|
||||
point_in_triangle = function(px, py, x1, y1, x2, y2, x3, y3)
|
||||
local area = math_ext.geometry.triangle_area(x1, y1, x2, y2, x3, y3)
|
||||
local area1 = math_ext.geometry.triangle_area(px, py, x2, y2, x3, y3)
|
||||
local area2 = math_ext.geometry.triangle_area(x1, y1, px, py, x3, y3)
|
||||
local area3 = math_ext.geometry.triangle_area(x1, y1, x2, y2, px, py)
|
||||
|
||||
return math.abs(area - (area1 + area2 + area3)) < 1e-10
|
||||
end,
|
||||
|
||||
-- Check if two line segments intersect
|
||||
line_intersect = function(x1, y1, x2, y2, x3, y3, x4, y4)
|
||||
local d = (y4 - y3) * (x2 - x1) - (x4 - x3) * (y2 - y1)
|
||||
|
||||
if math.abs(d) < 1e-10 then
|
||||
return false, nil, nil -- Lines are parallel
|
||||
end
|
||||
|
||||
local ua = ((x4 - x3) * (y1 - y3) - (y4 - y3) * (x1 - x3)) / d
|
||||
local ub = ((x2 - x1) * (y1 - y3) - (y2 - y1) * (x1 - x3)) / d
|
||||
|
||||
if ua >= 0 and ua <= 1 and ub >= 0 and ub <= 1 then
|
||||
local x = x1 + ua * (x2 - x1)
|
||||
local y = y1 + ua * (y2 - y1)
|
||||
return true, x, y
|
||||
end
|
||||
|
||||
return false, nil, nil
|
||||
end,
|
||||
|
||||
-- Closest point on line segment to point
|
||||
closest_point_on_segment = function(px, py, x1, y1, x2, y2)
|
||||
local dx, dy = x2 - x1, y2 - y1
|
||||
local len_sq = dx * dx + dy * dy
|
||||
|
||||
if len_sq < 1e-10 then
|
||||
return x1, y1
|
||||
end
|
||||
|
||||
local t = ((px - x1) * dx + (py - y1) * dy) / len_sq
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
|
||||
return x1 + t * dx, y1 + t * dy
|
||||
end
|
||||
}
|
||||
|
||||
-- ======================================================================
|
||||
-- INTERPOLATION FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
math_ext.interpolation = {
|
||||
-- Cubic Bezier interpolation
|
||||
bezier = function(t, p0, p1, p2, p3)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
local t2 = t * t
|
||||
local t3 = t2 * t
|
||||
local mt = 1 - t
|
||||
local mt2 = mt * mt
|
||||
local mt3 = mt2 * mt
|
||||
|
||||
return p0 * mt3 + 3 * p1 * mt2 * t + 3 * p2 * mt * t2 + p3 * t3
|
||||
end,
|
||||
|
||||
-- Catmull-Rom spline interpolation
|
||||
catmull_rom = function(t, p0, p1, p2, p3)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
local t2 = t * t
|
||||
local t3 = t2 * t
|
||||
|
||||
return 0.5 * (
|
||||
(2 * p1) +
|
||||
(-p0 + p2) * t +
|
||||
(2 * p0 - 5 * p1 + 4 * p2 - p3) * t2 +
|
||||
(-p0 + 3 * p1 - 3 * p2 + p3) * t3
|
||||
)
|
||||
end,
|
||||
|
||||
-- Hermite interpolation
|
||||
hermite = function(t, p0, p1, m0, m1)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
local t2 = t * t
|
||||
local t3 = t2 * t
|
||||
local h00 = 2 * t3 - 3 * t2 + 1
|
||||
local h10 = t3 - 2 * t2 + t
|
||||
local h01 = -2 * t3 + 3 * t2
|
||||
local h11 = t3 - t2
|
||||
|
||||
return h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
|
||||
end,
|
||||
|
||||
-- Quadratic Bezier interpolation
|
||||
quadratic_bezier = function(t, p0, p1, p2)
|
||||
t = math_ext.clamp(t, 0, 1)
|
||||
local mt = 1 - t
|
||||
return mt * mt * p0 + 2 * mt * t * p1 + t * t * p2
|
||||
end,
|
||||
|
||||
-- Step interpolation
|
||||
step = function(t, edge, x)
|
||||
return t < edge and 0 or x
|
||||
end,
|
||||
|
||||
-- Smoothstep interpolation
|
||||
smoothstep = function(edge0, edge1, x)
|
||||
local t = math_ext.clamp((x - edge0) / (edge1 - edge0), 0, 1)
|
||||
return t * t * (3 - 2 * t)
|
||||
end,
|
||||
|
||||
-- Smootherstep interpolation (Ken Perlin)
|
||||
smootherstep = function(edge0, edge1, x)
|
||||
local t = math_ext.clamp((x - edge0) / (edge1 - edge0), 0, 1)
|
||||
return t * t * t * (t * (t * 6 - 15) + 10)
|
||||
end
|
||||
}
|
||||
|
||||
return math_ext
|
@ -1,190 +0,0 @@
|
||||
-- render.lua
|
||||
|
||||
-- Template processing with code execution
|
||||
function render(template_str, env)
|
||||
local function is_control_structure(code)
|
||||
-- Check if code is a control structure that doesn't produce output
|
||||
local trimmed = code:match("^%s*(.-)%s*$")
|
||||
return trimmed == "else" or
|
||||
trimmed == "end" or
|
||||
trimmed:match("^if%s") or
|
||||
trimmed:match("^elseif%s") or
|
||||
trimmed:match("^for%s") or
|
||||
trimmed:match("^while%s") or
|
||||
trimmed:match("^repeat%s*$") or
|
||||
trimmed:match("^until%s") or
|
||||
trimmed:match("^do%s*$") or
|
||||
trimmed:match("^local%s") or
|
||||
trimmed:match("^function%s") or
|
||||
trimmed:match(".*=%s*function%s*%(") or
|
||||
trimmed:match(".*then%s*$") or
|
||||
trimmed:match(".*do%s*$")
|
||||
end
|
||||
|
||||
local pos, chunks = 1, {}
|
||||
while pos <= #template_str do
|
||||
local unescaped_start = template_str:find("{{{", pos, true)
|
||||
local escaped_start = template_str:find("{{", pos, true)
|
||||
|
||||
local start, tag_type, open_len
|
||||
if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then
|
||||
start, tag_type, open_len = unescaped_start, "-", 3
|
||||
elseif escaped_start then
|
||||
start, tag_type, open_len = escaped_start, "=", 2
|
||||
else
|
||||
table.insert(chunks, template_str:sub(pos))
|
||||
break
|
||||
end
|
||||
|
||||
if start > pos then
|
||||
table.insert(chunks, template_str:sub(pos, start-1))
|
||||
end
|
||||
|
||||
pos = start + open_len
|
||||
local close_tag = tag_type == "-" and "}}}" or "}}"
|
||||
local close_start, close_stop = template_str:find(close_tag, pos, true)
|
||||
if not close_start then
|
||||
error("Failed to find closing tag at position " .. pos)
|
||||
end
|
||||
|
||||
local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$")
|
||||
local is_control = is_control_structure(code)
|
||||
|
||||
table.insert(chunks, {tag_type, code, pos, is_control})
|
||||
pos = close_stop + 1
|
||||
end
|
||||
|
||||
local buffer = {"local _tostring, _escape, _b, _b_i = ...\n"}
|
||||
for _, chunk in ipairs(chunks) do
|
||||
local t = type(chunk)
|
||||
if t == "string" then
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "_b[_b_i] = " .. string.format("%q", chunk) .. "\n")
|
||||
else
|
||||
local tag_type, code, pos, is_control = chunk[1], chunk[2], chunk[3], chunk[4]
|
||||
|
||||
if is_control then
|
||||
-- Control structure - just insert as raw Lua code
|
||||
table.insert(buffer, "--[[" .. pos .. "]] " .. code .. "\n")
|
||||
elseif tag_type == "=" then
|
||||
-- Simple variable check
|
||||
if code:match("^[%w_]+$") then
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n")
|
||||
else
|
||||
-- Expression output with escaping
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _escape(_tostring(" .. code .. "))\n")
|
||||
end
|
||||
elseif tag_type == "-" then
|
||||
-- Unescaped output
|
||||
table.insert(buffer, "_b_i = _b_i + 1\n")
|
||||
table.insert(buffer, "--[[" .. pos .. "]] _b[_b_i] = _tostring(" .. code .. ")\n")
|
||||
end
|
||||
end
|
||||
end
|
||||
table.insert(buffer, "return _b")
|
||||
|
||||
local generated_code = table.concat(buffer)
|
||||
|
||||
-- DEBUG: Uncomment to see generated code
|
||||
-- print("Generated Lua code:")
|
||||
-- print(generated_code)
|
||||
-- print("---")
|
||||
|
||||
local fn, err = loadstring(generated_code)
|
||||
if not fn then
|
||||
print("Generated code that failed to compile:")
|
||||
print(generated_code)
|
||||
error(err)
|
||||
end
|
||||
|
||||
env = env or {}
|
||||
local runtime_env = setmetatable({}, {__index = function(_, k) return env[k] or _G[k] end})
|
||||
setfenv(fn, runtime_env)
|
||||
|
||||
local output_buffer = {}
|
||||
fn(tostring, html_special_chars, output_buffer, 0)
|
||||
return table.concat(output_buffer)
|
||||
end
|
||||
|
||||
-- Named placeholder processing
|
||||
function parse(template_str, env)
|
||||
local pos, output = 1, {}
|
||||
env = env or {}
|
||||
|
||||
while pos <= #template_str do
|
||||
local unescaped_start, unescaped_end, unescaped_name = template_str:find("{{{%s*([%w_]+)%s*}}}", pos)
|
||||
local escaped_start, escaped_end, escaped_name = template_str:find("{{%s*([%w_]+)%s*}}", pos)
|
||||
|
||||
local next_pos, placeholder_end, name, escaped
|
||||
if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then
|
||||
next_pos, placeholder_end, name, escaped = unescaped_start, unescaped_end, unescaped_name, false
|
||||
elseif escaped_start then
|
||||
next_pos, placeholder_end, name, escaped = escaped_start, escaped_end, escaped_name, true
|
||||
else
|
||||
local text = template_str:sub(pos)
|
||||
if text and #text > 0 then
|
||||
table.insert(output, text)
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
local text = template_str:sub(pos, next_pos - 1)
|
||||
if text and #text > 0 then
|
||||
table.insert(output, text)
|
||||
end
|
||||
|
||||
local value = env[name]
|
||||
local str = tostring(value or "")
|
||||
if escaped then
|
||||
str = html_special_chars(str)
|
||||
end
|
||||
table.insert(output, str)
|
||||
|
||||
pos = placeholder_end + 1
|
||||
end
|
||||
|
||||
return table.concat(output)
|
||||
end
|
||||
|
||||
-- Indexed placeholder processing
|
||||
function iparse(template_str, values)
|
||||
local pos, output, value_index = 1, {}, 1
|
||||
values = values or {}
|
||||
|
||||
while pos <= #template_str do
|
||||
local unescaped_start, unescaped_end = template_str:find("{{{}}}", pos, true)
|
||||
local escaped_start, escaped_end = template_str:find("{{}}", pos, true)
|
||||
|
||||
local next_pos, placeholder_end, escaped
|
||||
if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then
|
||||
next_pos, placeholder_end, escaped = unescaped_start, unescaped_end, false
|
||||
elseif escaped_start then
|
||||
next_pos, placeholder_end, escaped = escaped_start, escaped_end, true
|
||||
else
|
||||
local text = template_str:sub(pos)
|
||||
if text and #text > 0 then
|
||||
table.insert(output, text)
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
local text = template_str:sub(pos, next_pos - 1)
|
||||
if text and #text > 0 then
|
||||
table.insert(output, text)
|
||||
end
|
||||
|
||||
local value = values[value_index]
|
||||
local str = tostring(value or "")
|
||||
if escaped then
|
||||
str = html_special_chars(str)
|
||||
end
|
||||
table.insert(output, str)
|
||||
|
||||
pos = placeholder_end + 1
|
||||
value_index = value_index + 1
|
||||
end
|
||||
|
||||
return table.concat(output)
|
||||
end
|
@ -1,36 +0,0 @@
|
||||
-- sandbox.lua
|
||||
|
||||
function __execute(script_func, ctx, response)
|
||||
-- Store context and response globally for function access
|
||||
__ctx = ctx
|
||||
__response = response
|
||||
_G.ctx = ctx
|
||||
|
||||
-- Create a coroutine for script execution to handle early exits
|
||||
local co = coroutine.create(function()
|
||||
return script_func()
|
||||
end)
|
||||
|
||||
local ok, result = coroutine.resume(co)
|
||||
|
||||
-- Clean up
|
||||
__ctx = nil
|
||||
__response = nil
|
||||
|
||||
if not ok then
|
||||
-- Real error during script execution
|
||||
error(result, 0)
|
||||
end
|
||||
|
||||
-- Check if exit was requested
|
||||
if result == "__EXIT__" then
|
||||
return {nil, response}
|
||||
end
|
||||
|
||||
return {result, response}
|
||||
end
|
||||
|
||||
-- Exit sentinel using coroutine yield instead of error
|
||||
function exit()
|
||||
coroutine.yield("__EXIT__")
|
||||
end
|
@ -1,179 +0,0 @@
|
||||
-- session.lua
|
||||
|
||||
function session_set(key, value)
|
||||
__response.session = __response.session or {}
|
||||
__response.session[key] = value
|
||||
if __ctx.session and __ctx.session.data then
|
||||
__ctx.session.data[key] = value
|
||||
end
|
||||
end
|
||||
|
||||
function session_get(key)
|
||||
return __ctx.session and __ctx.session.data and __ctx.session.data[key]
|
||||
end
|
||||
|
||||
function session_id()
|
||||
return __ctx.session and __ctx.session.id
|
||||
end
|
||||
|
||||
function session_get_all()
|
||||
if __ctx.session and __ctx.session.data then
|
||||
local copy = {}
|
||||
for k, v in pairs(__ctx.session.data) do
|
||||
copy[k] = v
|
||||
end
|
||||
return copy
|
||||
end
|
||||
return {}
|
||||
end
|
||||
|
||||
function session_delete(key)
|
||||
__response.session = __response.session or {}
|
||||
__response.session[key] = "__DELETE__"
|
||||
if __ctx.session and __ctx.session.data then
|
||||
__ctx.session.data[key] = nil
|
||||
end
|
||||
end
|
||||
|
||||
function session_clear()
|
||||
__response.session = {__clear_all = true}
|
||||
if __ctx.session and __ctx.session.data then
|
||||
for k in pairs(__ctx.session.data) do
|
||||
__ctx.session.data[k] = nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function session_flash(key, value)
|
||||
__response.flash = __response.flash or {}
|
||||
__response.flash[key] = value
|
||||
end
|
||||
|
||||
function session_get_flash(key)
|
||||
-- Check current flash data first
|
||||
if __response.flash and __response.flash[key] ~= nil then
|
||||
return __response.flash[key]
|
||||
end
|
||||
|
||||
-- Check session flash data
|
||||
if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then
|
||||
return __ctx.session.flash[key]
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
function session_has_flash(key)
|
||||
-- Check current flash
|
||||
if __response.flash and __response.flash[key] ~= nil then
|
||||
return true
|
||||
end
|
||||
|
||||
-- Check session flash
|
||||
if __ctx.session and __ctx.session.flash and __ctx.session.flash[key] ~= nil then
|
||||
return true
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
function session_get_all_flash()
|
||||
local flash = {}
|
||||
|
||||
-- Add session flash data first
|
||||
if __ctx.session and __ctx.session.flash then
|
||||
for k, v in pairs(__ctx.session.flash) do
|
||||
flash[k] = v
|
||||
end
|
||||
end
|
||||
|
||||
-- Add current response flash (overwrites session flash if same key)
|
||||
if __response.flash then
|
||||
for k, v in pairs(__response.flash) do
|
||||
flash[k] = v
|
||||
end
|
||||
end
|
||||
|
||||
return flash
|
||||
end
|
||||
|
||||
function session_flash_now(key, value)
|
||||
-- Flash for current request only (not persisted)
|
||||
_G._current_flash = _G._current_flash or {}
|
||||
_G._current_flash[key] = value
|
||||
end
|
||||
|
||||
function session_get_flash_now(key)
|
||||
return _G._current_flash and _G._current_flash[key]
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- FLASH HELPER FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
function flash_success(message)
|
||||
session_flash("success", message)
|
||||
end
|
||||
|
||||
function flash_error(message)
|
||||
session_flash("error", message)
|
||||
end
|
||||
|
||||
function flash_warning(message)
|
||||
session_flash("warning", message)
|
||||
end
|
||||
|
||||
function flash_info(message)
|
||||
session_flash("info", message)
|
||||
end
|
||||
|
||||
function flash_message(type, message)
|
||||
session_flash(type, message)
|
||||
end
|
||||
|
||||
-- Get flash messages by type
|
||||
function get_flash_success()
|
||||
return session_get_flash("success")
|
||||
end
|
||||
|
||||
function get_flash_error()
|
||||
return session_get_flash("error")
|
||||
end
|
||||
|
||||
function get_flash_warning()
|
||||
return session_get_flash("warning")
|
||||
end
|
||||
|
||||
function get_flash_info()
|
||||
return session_get_flash("info")
|
||||
end
|
||||
|
||||
-- Check if flash messages exist
|
||||
function has_flash_success()
|
||||
return session_has_flash("success")
|
||||
end
|
||||
|
||||
function has_flash_error()
|
||||
return session_has_flash("error")
|
||||
end
|
||||
|
||||
function has_flash_warning()
|
||||
return session_has_flash("warning")
|
||||
end
|
||||
|
||||
function has_flash_info()
|
||||
return session_has_flash("info")
|
||||
end
|
||||
|
||||
function redirect_with_flash(url, type, message, status)
|
||||
session_flash(type or "info", message)
|
||||
http_redirect(url, status)
|
||||
end
|
||||
|
||||
function redirect_with_success(url, message, status)
|
||||
redirect_with_flash(url, "success", message, status)
|
||||
end
|
||||
|
||||
function redirect_with_error(url, message, status)
|
||||
redirect_with_flash(url, "error", message, status)
|
||||
end
|
@ -1,299 +0,0 @@
|
||||
-- sqlite.lua
|
||||
|
||||
local function normalize_params(params, ...)
|
||||
if type(params) == "table" then return params end
|
||||
local args = {...}
|
||||
if #args > 0 or params ~= nil then
|
||||
table.insert(args, 1, params)
|
||||
return args
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
local connection_mt = {
|
||||
__index = {
|
||||
query = function(self, query, params, ...)
|
||||
if type(query) ~= "string" then
|
||||
error("connection:query: query must be a string", 2)
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_query(self.db_name, query, normalized_params, __STATE_INDEX)
|
||||
end,
|
||||
|
||||
exec = function(self, query, params, ...)
|
||||
if type(query) ~= "string" then
|
||||
error("connection:exec: query must be a string", 2)
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_exec(self.db_name, query, normalized_params, __STATE_INDEX)
|
||||
end,
|
||||
|
||||
get_one = function(self, query, params, ...)
|
||||
if type(query) ~= "string" then
|
||||
error("connection:get_one: query must be a string", 2)
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_get_one(self.db_name, query, normalized_params, __STATE_INDEX)
|
||||
end,
|
||||
|
||||
insert = function(self, table_name, data, columns)
|
||||
if type(data) ~= "table" then
|
||||
error("connection:insert: data must be a table", 2)
|
||||
end
|
||||
|
||||
-- Single object: {col1=val1, col2=val2}
|
||||
if data[1] == nil and next(data) ~= nil then
|
||||
local cols = table.keys(data)
|
||||
local placeholders = table.map(cols, function(_, i) return ":p" .. i end)
|
||||
local params = {}
|
||||
for i, col in ipairs(cols) do
|
||||
params["p" .. i] = data[col]
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES (%s)",
|
||||
table_name,
|
||||
table.concat(cols, ", "),
|
||||
table.concat(placeholders, ", ")
|
||||
)
|
||||
return self:exec(query, params)
|
||||
end
|
||||
|
||||
-- Array data with columns
|
||||
if columns and type(columns) == "table" then
|
||||
if #data > 0 and type(data[1]) == "table" then
|
||||
-- Multiple rows
|
||||
local value_groups = {}
|
||||
local params = {}
|
||||
local param_idx = 1
|
||||
|
||||
for _, row in ipairs(data) do
|
||||
local row_placeholders = {}
|
||||
for j = 1, #columns do
|
||||
local param_name = "p" .. param_idx
|
||||
table.insert(row_placeholders, ":" .. param_name)
|
||||
params[param_name] = row[j]
|
||||
param_idx = param_idx + 1
|
||||
end
|
||||
table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES %s",
|
||||
table_name,
|
||||
table.concat(columns, ", "),
|
||||
table.concat(value_groups, ", ")
|
||||
)
|
||||
return self:exec(query, params)
|
||||
else
|
||||
-- Single row array
|
||||
local placeholders = table.map(columns, function(_, i) return ":p" .. i end)
|
||||
local params = {}
|
||||
for i = 1, #columns do
|
||||
params["p" .. i] = data[i]
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES (%s)",
|
||||
table_name,
|
||||
table.concat(columns, ", "),
|
||||
table.concat(placeholders, ", ")
|
||||
)
|
||||
return self:exec(query, params)
|
||||
end
|
||||
end
|
||||
|
||||
-- Array of objects
|
||||
if #data > 0 and type(data[1]) == "table" and data[1][1] == nil then
|
||||
local cols = table.keys(data[1])
|
||||
local value_groups = {}
|
||||
local params = {}
|
||||
local param_idx = 1
|
||||
|
||||
for _, row in ipairs(data) do
|
||||
local row_placeholders = {}
|
||||
for _, col in ipairs(cols) do
|
||||
local param_name = "p" .. param_idx
|
||||
table.insert(row_placeholders, ":" .. param_name)
|
||||
params[param_name] = row[col]
|
||||
param_idx = param_idx + 1
|
||||
end
|
||||
table.insert(value_groups, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES %s",
|
||||
table_name,
|
||||
table.concat(cols, ", "),
|
||||
table.concat(value_groups, ", ")
|
||||
)
|
||||
return self:exec(query, params)
|
||||
end
|
||||
|
||||
error("connection:insert: invalid data format", 2)
|
||||
end,
|
||||
|
||||
update = function(self, table_name, data, where, where_params, ...)
|
||||
if type(data) ~= "table" or next(data) == nil then
|
||||
return 0
|
||||
end
|
||||
|
||||
local sets = {}
|
||||
local params = {}
|
||||
local param_idx = 1
|
||||
|
||||
for col, val in pairs(data) do
|
||||
local param_name = "p" .. param_idx
|
||||
table.insert(sets, col .. " = :" .. param_name)
|
||||
params[param_name] = val
|
||||
param_idx = param_idx + 1
|
||||
end
|
||||
|
||||
local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", "))
|
||||
|
||||
if where then
|
||||
query = query .. " WHERE " .. where
|
||||
if where_params then
|
||||
local normalized = normalize_params(where_params, ...)
|
||||
if type(normalized) == "table" then
|
||||
for k, v in pairs(normalized) do
|
||||
if type(k) == "string" then
|
||||
params[k] = v
|
||||
else
|
||||
params["w" .. param_idx] = v
|
||||
param_idx = param_idx + 1
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return self:exec(query, params)
|
||||
end,
|
||||
|
||||
create_table = function(self, table_name, ...)
|
||||
local column_definitions = {}
|
||||
local index_definitions = {}
|
||||
|
||||
for _, def_string in ipairs({...}) do
|
||||
if type(def_string) == "string" then
|
||||
local is_unique = false
|
||||
local index_def = def_string
|
||||
|
||||
if string.starts_with(def_string, "UNIQUE INDEX:") then
|
||||
is_unique = true
|
||||
index_def = string.trim(def_string:sub(14))
|
||||
elseif string.starts_with(def_string, "INDEX:") then
|
||||
index_def = string.trim(def_string:sub(7))
|
||||
else
|
||||
table.insert(column_definitions, def_string)
|
||||
goto continue
|
||||
end
|
||||
|
||||
local paren_pos = index_def:find("%(")
|
||||
if not paren_pos then goto continue end
|
||||
|
||||
local index_name = string.trim(index_def:sub(1, paren_pos - 1))
|
||||
local columns_part = index_def:sub(paren_pos + 1):match("^(.-)%)%s*$")
|
||||
if not columns_part then goto continue end
|
||||
|
||||
local columns = table.map(string.split(columns_part, ","), string.trim)
|
||||
|
||||
if #columns > 0 then
|
||||
table.insert(index_definitions, {
|
||||
name = index_name,
|
||||
columns = columns,
|
||||
unique = is_unique
|
||||
})
|
||||
end
|
||||
end
|
||||
::continue::
|
||||
end
|
||||
|
||||
if #column_definitions == 0 then
|
||||
error("connection:create_table: no column definitions specified for table " .. table_name, 2)
|
||||
end
|
||||
|
||||
local statements = {}
|
||||
|
||||
table.insert(statements, string.format(
|
||||
"CREATE TABLE IF NOT EXISTS %s (%s)",
|
||||
table_name,
|
||||
table.concat(column_definitions, ", ")
|
||||
))
|
||||
|
||||
for _, idx in ipairs(index_definitions) do
|
||||
local unique_prefix = idx.unique and "UNIQUE " or ""
|
||||
table.insert(statements, string.format(
|
||||
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
|
||||
unique_prefix,
|
||||
idx.name,
|
||||
table_name,
|
||||
table.concat(idx.columns, ", ")
|
||||
))
|
||||
end
|
||||
|
||||
return self:exec(table.concat(statements, ";\n"))
|
||||
end,
|
||||
|
||||
delete = function(self, table_name, where, params, ...)
|
||||
local query = "DELETE FROM " .. table_name
|
||||
if where then
|
||||
query = query .. " WHERE " .. where
|
||||
end
|
||||
return self:exec(query, normalize_params(params, ...))
|
||||
end,
|
||||
|
||||
exists = function(self, table_name, where, params, ...)
|
||||
if type(table_name) ~= "string" then
|
||||
error("connection:exists: table_name must be a string", 2)
|
||||
end
|
||||
|
||||
local query = "SELECT 1 FROM " .. table_name
|
||||
if where then
|
||||
query = query .. " WHERE " .. where
|
||||
end
|
||||
query = query .. " LIMIT 1"
|
||||
|
||||
local results = self:query(query, normalize_params(params, ...))
|
||||
return #results > 0
|
||||
end,
|
||||
|
||||
begin = function(self)
|
||||
return self:exec("BEGIN TRANSACTION")
|
||||
end,
|
||||
|
||||
commit = function(self)
|
||||
return self:exec("COMMIT")
|
||||
end,
|
||||
|
||||
rollback = function(self)
|
||||
return self:exec("ROLLBACK")
|
||||
end,
|
||||
|
||||
transaction = function(self, callback)
|
||||
self:begin()
|
||||
local success, result = pcall(callback, self)
|
||||
if success then
|
||||
self:commit()
|
||||
return result
|
||||
else
|
||||
self:rollback()
|
||||
error(result, 2)
|
||||
end
|
||||
end
|
||||
}
|
||||
}
|
||||
|
||||
function sqlite(db_name)
|
||||
if type(db_name) ~= "string" then
|
||||
error("sqlite: database name must be a string", 2)
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
db_name = db_name
|
||||
}, connection_mt)
|
||||
end
|
@ -1,195 +0,0 @@
|
||||
-- string.lua
|
||||
|
||||
local string_ext = {}
|
||||
|
||||
-- ======================================================================
|
||||
-- STRING UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Trim whitespace from both ends
|
||||
function string_ext.trim(s)
|
||||
if type(s) ~= "string" then return s end
|
||||
return s:match("^%s*(.-)%s*$")
|
||||
end
|
||||
|
||||
-- Split string by delimiter
|
||||
function string_ext.split(s, delimiter)
|
||||
if type(s) ~= "string" then return {} end
|
||||
|
||||
delimiter = delimiter or ","
|
||||
local result = {}
|
||||
for match in (s..delimiter):gmatch("(.-)"..delimiter) do
|
||||
table.insert(result, match)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- Check if string starts with prefix
|
||||
function string_ext.starts_with(s, prefix)
|
||||
if type(s) ~= "string" or type(prefix) ~= "string" then return false end
|
||||
return s:sub(1, #prefix) == prefix
|
||||
end
|
||||
|
||||
-- Check if string ends with suffix
|
||||
function string_ext.ends_with(s, suffix)
|
||||
if type(s) ~= "string" or type(suffix) ~= "string" then return false end
|
||||
return suffix == "" or s:sub(-#suffix) == suffix
|
||||
end
|
||||
|
||||
-- Left pad a string
|
||||
function string_ext.pad_left(s, len, char)
|
||||
if type(s) ~= "string" or type(len) ~= "number" then return s end
|
||||
|
||||
char = char or " "
|
||||
if #s >= len then return s end
|
||||
|
||||
return string.rep(char:sub(1,1), len - #s) .. s
|
||||
end
|
||||
|
||||
-- Right pad a string
|
||||
function string_ext.pad_right(s, len, char)
|
||||
if type(s) ~= "string" or type(len) ~= "number" then return s end
|
||||
|
||||
char = char or " "
|
||||
if #s >= len then return s end
|
||||
|
||||
return s .. string.rep(char:sub(1,1), len - #s)
|
||||
end
|
||||
|
||||
-- Center a string
|
||||
function string_ext.center(s, width, char)
|
||||
if type(s) ~= "string" or width <= #s then return s end
|
||||
|
||||
char = char or " "
|
||||
local pad_len = width - #s
|
||||
local left_pad = math.floor(pad_len / 2)
|
||||
local right_pad = pad_len - left_pad
|
||||
|
||||
return string.rep(char:sub(1,1), left_pad) .. s .. string.rep(char:sub(1,1), right_pad)
|
||||
end
|
||||
|
||||
-- Count occurrences of substring
|
||||
function string_ext.count(s, substr)
|
||||
if type(s) ~= "string" or type(substr) ~= "string" or #substr == 0 then return 0 end
|
||||
|
||||
local count, pos = 0, 1
|
||||
while true do
|
||||
pos = s:find(substr, pos, true)
|
||||
if not pos then break end
|
||||
count = count + 1
|
||||
pos = pos + 1
|
||||
end
|
||||
return count
|
||||
end
|
||||
|
||||
-- Capitalize first letter
|
||||
function string_ext.capitalize(s)
|
||||
if type(s) ~= "string" or #s == 0 then return s end
|
||||
return s:sub(1,1):upper() .. s:sub(2)
|
||||
end
|
||||
|
||||
-- Capitalize all words
|
||||
function string_ext.title(s)
|
||||
if type(s) ~= "string" then return s end
|
||||
|
||||
return s:gsub("(%w)([%w]*)", function(first, rest)
|
||||
return first:upper() .. rest:lower()
|
||||
end)
|
||||
end
|
||||
|
||||
-- Insert string at position
|
||||
function string_ext.insert(s, pos, insert_str)
|
||||
if type(s) ~= "string" or type(insert_str) ~= "string" then return s end
|
||||
|
||||
pos = math.max(1, math.min(pos, #s + 1))
|
||||
return s:sub(1, pos - 1) .. insert_str .. s:sub(pos)
|
||||
end
|
||||
|
||||
-- Remove substring
|
||||
function string_ext.remove(s, start, length)
|
||||
if type(s) ~= "string" then return s end
|
||||
|
||||
length = length or 1
|
||||
if start < 1 or start > #s then return s end
|
||||
|
||||
return s:sub(1, start - 1) .. s:sub(start + length)
|
||||
end
|
||||
|
||||
-- Replace substring once
|
||||
function string_ext.replace(s, old, new, n)
|
||||
if type(s) ~= "string" or type(old) ~= "string" or #old == 0 then return s end
|
||||
|
||||
new = new or ""
|
||||
n = n or 1
|
||||
|
||||
return s:gsub(old:gsub("[%-%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%1"), new, n)
|
||||
end
|
||||
|
||||
-- Check if string contains substring
|
||||
function string_ext.contains(s, substr)
|
||||
if type(s) ~= "string" or type(substr) ~= "string" then return false end
|
||||
return s:find(substr, 1, true) ~= nil
|
||||
end
|
||||
|
||||
-- Escape pattern magic characters
|
||||
function string_ext.escape_pattern(s)
|
||||
if type(s) ~= "string" then return s end
|
||||
return s:gsub("[%-%^%$%(%)%%%.%[%]%*%+%-%?]", "%%%1")
|
||||
end
|
||||
|
||||
-- Wrap text at specified width
|
||||
function string_ext.wrap(s, width, indent_first, indent_rest)
|
||||
if type(s) ~= "string" or type(width) ~= "number" then return s end
|
||||
|
||||
width = math.max(1, width)
|
||||
indent_first = indent_first or ""
|
||||
indent_rest = indent_rest or indent_first
|
||||
|
||||
local result = {}
|
||||
local line_prefix = indent_first
|
||||
local pos = 1
|
||||
|
||||
while pos <= #s do
|
||||
local line_width = width - #line_prefix
|
||||
local end_pos = math.min(pos + line_width - 1, #s)
|
||||
|
||||
if end_pos < #s then
|
||||
local last_space = s:sub(pos, end_pos):match(".*%s()")
|
||||
if last_space then
|
||||
end_pos = pos + last_space - 2
|
||||
end
|
||||
end
|
||||
|
||||
table.insert(result, line_prefix .. s:sub(pos, end_pos))
|
||||
pos = end_pos + 1
|
||||
|
||||
-- Skip leading spaces on next line
|
||||
while s:sub(pos, pos) == " " do
|
||||
pos = pos + 1
|
||||
end
|
||||
|
||||
line_prefix = indent_rest
|
||||
end
|
||||
|
||||
return table.concat(result, "\n")
|
||||
end
|
||||
|
||||
-- Limit string length with ellipsis
|
||||
function string_ext.truncate(s, length, ellipsis)
|
||||
if type(s) ~= "string" then return s end
|
||||
|
||||
ellipsis = ellipsis or "..."
|
||||
if #s <= length then return s end
|
||||
|
||||
return s:sub(1, length - #ellipsis) .. ellipsis
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- INSTALL EXTENSIONS INTO STRING LIBRARY
|
||||
-- ======================================================================
|
||||
|
||||
for name, func in pairs(string) do
|
||||
string_ext[name] = func
|
||||
end
|
||||
|
||||
return string_ext
|
1090
runner/lua/table.lua
1090
runner/lua/table.lua
File diff suppressed because it is too large
Load Diff
@ -1,128 +0,0 @@
|
||||
-- time.lua
|
||||
|
||||
local ffi = require('ffi')
|
||||
local is_windows = (ffi.os == "Windows")
|
||||
|
||||
-- Define C structures and functions based on platform
|
||||
if is_windows then
|
||||
ffi.cdef[[
|
||||
typedef struct {
|
||||
int64_t QuadPart;
|
||||
} LARGE_INTEGER;
|
||||
int QueryPerformanceCounter(LARGE_INTEGER* lpPerformanceCount);
|
||||
int QueryPerformanceFrequency(LARGE_INTEGER* lpFrequency);
|
||||
]]
|
||||
else
|
||||
ffi.cdef[[
|
||||
typedef long time_t;
|
||||
typedef struct timeval {
|
||||
long tv_sec;
|
||||
long tv_usec;
|
||||
} timeval;
|
||||
int gettimeofday(struct timeval* tv, void* tz);
|
||||
time_t time(time_t* t);
|
||||
]]
|
||||
end
|
||||
|
||||
local time = {}
|
||||
local has_initialized = false
|
||||
local start_time, timer_freq
|
||||
|
||||
-- Initialize timing system based on platform
|
||||
local function init()
|
||||
if has_initialized then return end
|
||||
|
||||
if ffi.os == "Windows" then
|
||||
local frequency = ffi.new("LARGE_INTEGER")
|
||||
ffi.C.QueryPerformanceFrequency(frequency)
|
||||
timer_freq = tonumber(frequency.QuadPart)
|
||||
|
||||
local counter = ffi.new("LARGE_INTEGER")
|
||||
ffi.C.QueryPerformanceCounter(counter)
|
||||
start_time = tonumber(counter.QuadPart)
|
||||
else
|
||||
-- Nothing special needed for Unix platform init
|
||||
start_time = ffi.C.time(nil)
|
||||
end
|
||||
|
||||
has_initialized = true
|
||||
end
|
||||
|
||||
-- PHP-compatible microtime implementation
|
||||
function time.microtime(get_as_float)
|
||||
init()
|
||||
|
||||
if ffi.os == "Windows" then
|
||||
local counter = ffi.new("LARGE_INTEGER")
|
||||
ffi.C.QueryPerformanceCounter(counter)
|
||||
local now = tonumber(counter.QuadPart)
|
||||
local seconds = math.floor((now - start_time) / timer_freq)
|
||||
local microseconds = ((now - start_time) % timer_freq) * 1000000 / timer_freq
|
||||
|
||||
if get_as_float then
|
||||
return seconds + microseconds / 1000000
|
||||
else
|
||||
return string.format("0.%06d %d", microseconds, seconds)
|
||||
end
|
||||
else
|
||||
local tv = ffi.new("struct timeval")
|
||||
ffi.C.gettimeofday(tv, nil)
|
||||
|
||||
if get_as_float then
|
||||
return tonumber(tv.tv_sec) + tonumber(tv.tv_usec) / 1000000
|
||||
else
|
||||
return string.format("0.%06d %d", tv.tv_usec, tv.tv_sec)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- High-precision monotonic timer (returns seconds with microsecond precision)
|
||||
function time.monotonic()
|
||||
init()
|
||||
|
||||
if ffi.os == "Windows" then
|
||||
local counter = ffi.new("LARGE_INTEGER")
|
||||
ffi.C.QueryPerformanceCounter(counter)
|
||||
local now = tonumber(counter.QuadPart)
|
||||
return (now - start_time) / timer_freq
|
||||
else
|
||||
local tv = ffi.new("struct timeval")
|
||||
ffi.C.gettimeofday(tv, nil)
|
||||
return tonumber(tv.tv_sec) - start_time + tonumber(tv.tv_usec) / 1000000
|
||||
end
|
||||
end
|
||||
|
||||
-- Benchmark function that measures execution time
|
||||
function time.benchmark(func, iterations, warmup)
|
||||
iterations = iterations or 1000
|
||||
warmup = warmup or 10
|
||||
|
||||
-- Warmup
|
||||
for i=1, warmup do func() end
|
||||
|
||||
local start = time.microtime(true)
|
||||
for i=1, iterations do
|
||||
func()
|
||||
end
|
||||
local finish = time.microtime(true)
|
||||
|
||||
local elapsed = (finish - start) * 1000000 -- Convert to microseconds
|
||||
return elapsed / iterations
|
||||
end
|
||||
|
||||
-- Simple sleep function using coroutine yielding
|
||||
function time.sleep(seconds)
|
||||
if type(seconds) ~= "number" or seconds <= 0 then
|
||||
return
|
||||
end
|
||||
|
||||
local start = time.monotonic()
|
||||
while time.monotonic() - start < seconds do
|
||||
-- Use coroutine.yield to avoid consuming CPU
|
||||
coroutine.yield()
|
||||
end
|
||||
end
|
||||
|
||||
_G.microtime = time.microtime
|
||||
|
||||
return time
|
@ -1,93 +0,0 @@
|
||||
-- timestamp.lua
|
||||
local timestamp = {}
|
||||
|
||||
-- Standard format presets using Lua format codes
|
||||
local FORMATS = {
|
||||
iso = "%Y-%m-%dT%H:%M:%SZ",
|
||||
datetime = "%Y-%m-%d %H:%M:%S",
|
||||
us_date = "%m/%d/%Y",
|
||||
us_datetime = "%m/%d/%Y %I:%M:%S %p",
|
||||
date = "%Y-%m-%d",
|
||||
time = "%H:%M:%S",
|
||||
time12 = "%I:%M:%S %p",
|
||||
readable = "%B %d, %Y %I:%M:%S %p",
|
||||
compact = "%Y%m%d_%H%M%S"
|
||||
}
|
||||
|
||||
-- Parse input to unix timestamp and microseconds
|
||||
local function parse_input(input)
|
||||
local unix_time, micros = 0, 0
|
||||
|
||||
if type(input) == "string" then
|
||||
local frac, secs = input:match("^(0%.%d+)%s+(%d+)$")
|
||||
if frac and secs then
|
||||
unix_time = tonumber(secs)
|
||||
micros = math.floor((tonumber(frac) * 1000000) + 0.5)
|
||||
else
|
||||
unix_time = tonumber(input) or 0
|
||||
end
|
||||
elseif type(input) == "number" then
|
||||
unix_time = math.floor(input)
|
||||
micros = math.floor(((input - unix_time) * 1000000) + 0.5)
|
||||
end
|
||||
|
||||
return unix_time, micros
|
||||
end
|
||||
|
||||
-- Remove leading zeros from number string
|
||||
local function no_leading_zero(s)
|
||||
return s:gsub("^0+", "") or "0"
|
||||
end
|
||||
|
||||
-- Main format function
|
||||
function timestamp.format(input, fmt)
|
||||
fmt = fmt or "datetime"
|
||||
local format_str = FORMATS[fmt] or fmt
|
||||
local unix_time, micros = parse_input(input)
|
||||
local result = os.date(format_str, unix_time)
|
||||
|
||||
-- Handle microseconds if format contains dot
|
||||
if format_str:find("%.") then
|
||||
result = result .. string.format(".%06d", micros)
|
||||
end
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
-- US date/time with no leading zeros
|
||||
function timestamp.us_datetime_no_zero(input)
|
||||
local unix_time, micros = parse_input(input)
|
||||
local month = no_leading_zero(os.date("%m", unix_time))
|
||||
local day = no_leading_zero(os.date("%d", unix_time))
|
||||
local year = os.date("%Y", unix_time)
|
||||
local hour = no_leading_zero(os.date("%I", unix_time))
|
||||
local min = os.date("%M", unix_time)
|
||||
local sec = os.date("%S", unix_time)
|
||||
local ampm = os.date("%p", unix_time)
|
||||
|
||||
return string.format("%s/%s/%s %s:%s:%s %s", month, day, year, hour, min, sec, ampm)
|
||||
end
|
||||
|
||||
-- Quick preset functions
|
||||
function timestamp.iso(input) return timestamp.format(input, "iso") end
|
||||
function timestamp.datetime(input) return timestamp.format(input, "datetime") end
|
||||
function timestamp.us_date(input) return timestamp.format(input, "us_date") end
|
||||
function timestamp.us_datetime(input) return timestamp.us_datetime_no_zero(input) end
|
||||
function timestamp.date(input) return timestamp.format(input, "date") end
|
||||
function timestamp.time(input) return timestamp.format(input, "time") end
|
||||
function timestamp.time12(input) return timestamp.format(input, "time12") end
|
||||
function timestamp.readable(input) return timestamp.format(input, "readable") end
|
||||
|
||||
-- Microsecond precision variants
|
||||
function timestamp.datetime_micro(input)
|
||||
return timestamp.format(input, "%Y-%m-%d %H:%M:%S.") .. string.format("%06d", select(2, parse_input(input)))
|
||||
end
|
||||
|
||||
function timestamp.iso_micro(input)
|
||||
return timestamp.format(input, "%Y-%m-%dT%H:%M:%S.") .. string.format("%06dZ", select(2, parse_input(input)))
|
||||
end
|
||||
|
||||
-- Register global convenience function
|
||||
_G.format_time = timestamp.format
|
||||
|
||||
return timestamp
|
@ -1,323 +0,0 @@
|
||||
-- util.lua
|
||||
|
||||
-- ======================================================================
|
||||
-- CORE UTILITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Generate a random token
|
||||
function generate_token(length)
|
||||
return __generate_token(length or 32)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- HTML ENTITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Convert special characters to HTML entities (like htmlspecialchars)
|
||||
function html_special_chars(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return __html_special_chars(str)
|
||||
end
|
||||
|
||||
-- Convert all applicable characters to HTML entities (like htmlentities)
|
||||
function html_entities(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return __html_entities(str)
|
||||
end
|
||||
|
||||
-- Convert HTML entities back to characters (simple version)
|
||||
function html_entity_decode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
str = str:gsub("<", "<")
|
||||
str = str:gsub(">", ">")
|
||||
str = str:gsub(""", '"')
|
||||
str = str:gsub("'", "'")
|
||||
str = str:gsub("&", "&")
|
||||
|
||||
return str
|
||||
end
|
||||
|
||||
-- Convert newlines to <br> tags
|
||||
function nl2br(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return str:gsub("\r\n", "<br>"):gsub("\n", "<br>"):gsub("\r", "<br>")
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- URL FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- URL encode a string
|
||||
function url_encode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
str = str:gsub("\n", "\r\n")
|
||||
str = str:gsub("([^%w %-%_%.%~])", function(c)
|
||||
return string.format("%%%02X", string.byte(c))
|
||||
end)
|
||||
str = str:gsub(" ", "+")
|
||||
return str
|
||||
end
|
||||
|
||||
-- URL decode a string
|
||||
function url_decode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
str = str:gsub("+", " ")
|
||||
str = str:gsub("%%(%x%x)", function(h)
|
||||
return string.char(tonumber(h, 16))
|
||||
end)
|
||||
return str
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- VALIDATION FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Email validation
|
||||
function is_email(str)
|
||||
if type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Simple email validation pattern
|
||||
local pattern = "^[%w%.%%%+%-]+@[%w%.%%%+%-]+%.%w%w%w?%w?$"
|
||||
return str:match(pattern) ~= nil
|
||||
end
|
||||
|
||||
-- URL validation
|
||||
function is_url(str)
|
||||
if type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
-- Simple URL validation
|
||||
local pattern = "^https?://[%w-_%.%?%.:/%+=&%%]+$"
|
||||
return str:match(pattern) ~= nil
|
||||
end
|
||||
|
||||
-- IP address validation (IPv4)
|
||||
function is_ipv4(str)
|
||||
if type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
local pattern = "^(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)%.(%d%d?%d?)$"
|
||||
local a, b, c, d = str:match(pattern)
|
||||
|
||||
if not (a and b and c and d) then
|
||||
return false
|
||||
end
|
||||
|
||||
a, b, c, d = tonumber(a), tonumber(b), tonumber(c), tonumber(d)
|
||||
return a <= 255 and b <= 255 and c <= 255 and d <= 255
|
||||
end
|
||||
|
||||
-- Integer validation
|
||||
function is_int(str)
|
||||
if type(str) == "number" then
|
||||
return math.floor(str) == str
|
||||
elseif type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
return str:match("^-?%d+$") ~= nil
|
||||
end
|
||||
|
||||
-- Float validation
|
||||
function is_float(str)
|
||||
if type(str) == "number" then
|
||||
return true
|
||||
elseif type(str) ~= "string" then
|
||||
return false
|
||||
end
|
||||
|
||||
return str:match("^-?%d+%.?%d*$") ~= nil
|
||||
end
|
||||
|
||||
-- Boolean validation
|
||||
function is_bool(value)
|
||||
if type(value) == "boolean" then
|
||||
return true
|
||||
elseif type(value) ~= "string" and type(value) ~= "number" then
|
||||
return false
|
||||
end
|
||||
|
||||
local v = type(value) == "string" and value:lower() or value
|
||||
return v == "1" or v == "true" or v == "on" or v == "yes" or
|
||||
v == "0" or v == "false" or v == "off" or v == "no" or
|
||||
v == 1 or v == 0
|
||||
end
|
||||
|
||||
-- Convert to boolean
|
||||
function to_bool(value)
|
||||
if type(value) == "boolean" then
|
||||
return value
|
||||
elseif type(value) ~= "string" and type(value) ~= "number" then
|
||||
return false
|
||||
end
|
||||
|
||||
local v = type(value) == "string" and value:lower() or value
|
||||
return v == "1" or v == "true" or v == "on" or v == "yes" or v == 1
|
||||
end
|
||||
|
||||
-- Sanitize string (simple version)
|
||||
function sanitize_string(str)
|
||||
if type(str) ~= "string" then
|
||||
return ""
|
||||
end
|
||||
|
||||
return html_special_chars(str)
|
||||
end
|
||||
|
||||
-- Sanitize to integer
|
||||
function sanitize_int(value)
|
||||
if type(value) ~= "string" and type(value) ~= "number" then
|
||||
return 0
|
||||
end
|
||||
|
||||
value = tostring(value)
|
||||
local result = value:match("^-?%d+")
|
||||
return result and tonumber(result) or 0
|
||||
end
|
||||
|
||||
-- Sanitize to float
|
||||
function sanitize_float(value)
|
||||
if type(value) ~= "string" and type(value) ~= "number" then
|
||||
return 0
|
||||
end
|
||||
|
||||
value = tostring(value)
|
||||
local result = value:match("^-?%d+%.?%d*")
|
||||
return result and tonumber(result) or 0
|
||||
end
|
||||
|
||||
-- Sanitize URL
|
||||
function sanitize_url(str)
|
||||
if type(str) ~= "string" then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Basic sanitization by removing control characters
|
||||
str = str:gsub("[\000-\031]", "")
|
||||
|
||||
-- Make sure it's a valid URL
|
||||
if is_url(str) then
|
||||
return str
|
||||
end
|
||||
|
||||
-- Try to prepend http:// if it's missing
|
||||
if not str:match("^https?://") and is_url("http://" .. str) then
|
||||
return "http://" .. str
|
||||
end
|
||||
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Sanitize email
|
||||
function sanitize_email(str)
|
||||
if type(str) ~= "string" then
|
||||
return ""
|
||||
end
|
||||
|
||||
-- Remove all characters except common email characters
|
||||
str = str:gsub("[^%a%d%!%#%$%%%&%'%*%+%-%/%=%?%^%_%`%{%|%}%~%@%.%[%]]", "")
|
||||
|
||||
-- Return only if it's a valid email
|
||||
if is_email(str) then
|
||||
return str
|
||||
end
|
||||
|
||||
return ""
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- SECURITY FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Basic XSS prevention
|
||||
function xss_clean(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
-- Convert problematic characters to entities
|
||||
local result = html_special_chars(str)
|
||||
|
||||
-- Remove JavaScript event handlers
|
||||
result = result:gsub("on%w+%s*=", "")
|
||||
|
||||
-- Remove JavaScript protocol
|
||||
result = result:gsub("javascript:", "")
|
||||
|
||||
-- Remove CSS expression
|
||||
result = result:gsub("expression%s*%(", "")
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
-- Base64 encode
|
||||
function base64_encode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return __base64_encode(str)
|
||||
end
|
||||
|
||||
-- Base64 decode
|
||||
function base64_decode(str)
|
||||
if type(str) ~= "string" then
|
||||
return str
|
||||
end
|
||||
|
||||
return __base64_decode(str)
|
||||
end
|
||||
|
||||
-- ======================================================================
|
||||
-- PASSWORD FUNCTIONS
|
||||
-- ======================================================================
|
||||
|
||||
-- Hash a password using Argon2id
|
||||
-- Options:
|
||||
-- memory: Amount of memory to use in KB (default: 128MB)
|
||||
-- iterations: Number of iterations (default: 4)
|
||||
-- parallelism: Number of threads (default: 4)
|
||||
-- salt_length: Length of salt in bytes (default: 16)
|
||||
-- key_length: Length of the derived key in bytes (default: 32)
|
||||
function password_hash(plain_password, options)
|
||||
if type(plain_password) ~= "string" then
|
||||
error("password_hash: expected string password", 2)
|
||||
end
|
||||
|
||||
return __password_hash(plain_password, options)
|
||||
end
|
||||
|
||||
-- Verify a password against a hash
|
||||
function password_verify(plain_password, hash_string)
|
||||
if type(plain_password) ~= "string" then
|
||||
error("password_verify: expected string password", 2)
|
||||
end
|
||||
|
||||
if type(hash_string) ~= "string" then
|
||||
error("password_verify: expected string hash", 2)
|
||||
end
|
||||
|
||||
return __password_verify(plain_password, hash_string)
|
||||
end
|
@ -1,463 +0,0 @@
|
||||
package lualibs
|
||||
|
||||
import (
|
||||
"Moonshark/logger"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"hash"
|
||||
"math"
|
||||
mrand "math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
var (
|
||||
// Map to store state-specific RNGs
|
||||
stateRngs = make(map[*luajit.State]*mrand.PCG)
|
||||
stateRngsMu sync.Mutex
|
||||
)
|
||||
|
||||
// RegisterCryptoFunctions registers all crypto functions with the Lua state
|
||||
func RegisterCryptoFunctions(state *luajit.State) error {
|
||||
// Create a state-specific RNG
|
||||
stateRngsMu.Lock()
|
||||
stateRngs[state] = mrand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano()>>32))
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register hash functions
|
||||
if err := state.RegisterGoFunction("__crypto_hash", cryptoHash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register HMAC functions
|
||||
if err := state.RegisterGoFunction("__crypto_hmac", cryptoHmac); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register UUID generation
|
||||
if err := state.RegisterGoFunction("__crypto_uuid", cryptoUuid); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register random functions
|
||||
if err := state.RegisterGoFunction("__crypto_random", cryptoRandom); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__crypto_random_bytes", cryptoRandomBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__crypto_random_int", cryptoRandomInt); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__crypto_random_seed", cryptoRandomSeed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Override Lua's math.random
|
||||
if err := OverrideLuaRandom(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupCrypto cleans up resources when a state is closed
|
||||
func CleanupCrypto(state *luajit.State) {
|
||||
stateRngsMu.Lock()
|
||||
delete(stateRngs, state)
|
||||
stateRngsMu.Unlock()
|
||||
}
|
||||
|
||||
// cryptoHash generates hash digests using various algorithms
|
||||
func cryptoHash(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(2); err != nil {
|
||||
return state.PushError("hash: %v", err)
|
||||
}
|
||||
|
||||
data, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("hash: data must be string")
|
||||
}
|
||||
|
||||
algorithm, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
return state.PushError("hash: algorithm must be string")
|
||||
}
|
||||
|
||||
var h hash.Hash
|
||||
switch algorithm {
|
||||
case "md5":
|
||||
h = md5.New()
|
||||
case "sha1":
|
||||
h = sha1.New()
|
||||
case "sha256":
|
||||
h = sha256.New()
|
||||
case "sha512":
|
||||
h = sha512.New()
|
||||
default:
|
||||
return state.PushError("unsupported algorithm: %s", algorithm)
|
||||
}
|
||||
|
||||
h.Write([]byte(data))
|
||||
hashBytes := h.Sum(nil)
|
||||
|
||||
// Output format
|
||||
outputFormat := "hex"
|
||||
if state.GetTop() >= 3 {
|
||||
if format, err := state.SafeToString(3); err == nil {
|
||||
outputFormat = format
|
||||
}
|
||||
}
|
||||
|
||||
switch outputFormat {
|
||||
case "hex":
|
||||
state.PushString(hex.EncodeToString(hashBytes))
|
||||
case "binary":
|
||||
state.PushString(string(hashBytes))
|
||||
default:
|
||||
state.PushString(hex.EncodeToString(hashBytes))
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoHmac generates HMAC using various hash algorithms
|
||||
func cryptoHmac(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(3); err != nil {
|
||||
return state.PushError("hmac: %v", err)
|
||||
}
|
||||
|
||||
data, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("hmac: data must be string")
|
||||
}
|
||||
|
||||
key, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
return state.PushError("hmac: key must be string")
|
||||
}
|
||||
|
||||
algorithm, err := state.SafeToString(3)
|
||||
if err != nil {
|
||||
return state.PushError("hmac: algorithm must be string")
|
||||
}
|
||||
|
||||
var h func() hash.Hash
|
||||
switch algorithm {
|
||||
case "md5":
|
||||
h = md5.New
|
||||
case "sha1":
|
||||
h = sha1.New
|
||||
case "sha256":
|
||||
h = sha256.New
|
||||
case "sha512":
|
||||
h = sha512.New
|
||||
default:
|
||||
return state.PushError("unsupported algorithm: %s", algorithm)
|
||||
}
|
||||
|
||||
mac := hmac.New(h, []byte(key))
|
||||
mac.Write([]byte(data))
|
||||
macBytes := mac.Sum(nil)
|
||||
|
||||
// Output format
|
||||
outputFormat := "hex"
|
||||
if state.GetTop() >= 4 {
|
||||
if format, err := state.SafeToString(4); err == nil {
|
||||
outputFormat = format
|
||||
}
|
||||
}
|
||||
|
||||
switch outputFormat {
|
||||
case "hex":
|
||||
state.PushString(hex.EncodeToString(macBytes))
|
||||
case "binary":
|
||||
state.PushString(string(macBytes))
|
||||
default:
|
||||
state.PushString(hex.EncodeToString(macBytes))
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoUuid generates a random UUID v4
|
||||
func cryptoUuid(state *luajit.State) int {
|
||||
uuid := make([]byte, 16)
|
||||
if _, err := rand.Read(uuid); err != nil {
|
||||
return state.PushError("uuid: generation error: %v", err)
|
||||
}
|
||||
|
||||
// Set version (4) and variant (RFC 4122)
|
||||
uuid[6] = (uuid[6] & 0x0F) | 0x40
|
||||
uuid[8] = (uuid[8] & 0x3F) | 0x80
|
||||
|
||||
uuidStr := fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:])
|
||||
|
||||
state.PushString(uuidStr)
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoRandomBytes generates random bytes
|
||||
func cryptoRandomBytes(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(1); err != nil {
|
||||
return state.PushError("random_bytes: %v", err)
|
||||
}
|
||||
|
||||
length, err := state.SafeToNumber(1)
|
||||
if err != nil {
|
||||
return state.PushError("random_bytes: length must be number")
|
||||
}
|
||||
|
||||
if length <= 0 {
|
||||
return state.PushError("random_bytes: length must be positive")
|
||||
}
|
||||
|
||||
// Check if secure
|
||||
secure := true
|
||||
if state.GetTop() >= 2 && state.IsBoolean(2) {
|
||||
secure = state.ToBoolean(2)
|
||||
}
|
||||
|
||||
bytes := make([]byte, int(length))
|
||||
|
||||
if secure {
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return state.PushError("random_bytes: error: %v", err)
|
||||
}
|
||||
} else {
|
||||
stateRngsMu.Lock()
|
||||
stateRng, ok := stateRngs[state]
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return state.PushError("random_bytes: RNG not initialized")
|
||||
}
|
||||
|
||||
for i := range bytes {
|
||||
bytes[i] = byte(stateRng.Uint64() & 0xFF)
|
||||
}
|
||||
}
|
||||
|
||||
// Output format
|
||||
outputFormat := "binary"
|
||||
if state.GetTop() >= 3 {
|
||||
if format, err := state.SafeToString(3); err == nil {
|
||||
outputFormat = format
|
||||
}
|
||||
}
|
||||
|
||||
switch outputFormat {
|
||||
case "binary":
|
||||
state.PushString(string(bytes))
|
||||
case "hex":
|
||||
state.PushString(hex.EncodeToString(bytes))
|
||||
default:
|
||||
state.PushString(string(bytes))
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoRandomInt generates a random integer in range [min, max]
|
||||
func cryptoRandomInt(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(2); err != nil {
|
||||
return state.PushError("random_int: %v", err)
|
||||
}
|
||||
|
||||
minVal, err := state.SafeToNumber(1)
|
||||
if err != nil {
|
||||
return state.PushError("random_int: min must be number")
|
||||
}
|
||||
|
||||
maxVal, err := state.SafeToNumber(2)
|
||||
if err != nil {
|
||||
return state.PushError("random_int: max must be number")
|
||||
}
|
||||
|
||||
min := int64(minVal)
|
||||
max := int64(maxVal)
|
||||
|
||||
if max <= min {
|
||||
return state.PushError("random_int: max must be greater than min")
|
||||
}
|
||||
|
||||
// Check if secure
|
||||
secure := true
|
||||
if state.GetTop() >= 3 && state.IsBoolean(3) {
|
||||
secure = state.ToBoolean(3)
|
||||
}
|
||||
|
||||
range_size := max - min + 1
|
||||
var result int64
|
||||
|
||||
if secure {
|
||||
bytes := make([]byte, 8)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return state.PushError("random_int: error: %v", err)
|
||||
}
|
||||
|
||||
val := binary.BigEndian.Uint64(bytes)
|
||||
result = min + int64(val%uint64(range_size))
|
||||
} else {
|
||||
stateRngsMu.Lock()
|
||||
stateRng, ok := stateRngs[state]
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return state.PushError("random_int: RNG not initialized")
|
||||
}
|
||||
|
||||
result = min + int64(stateRng.Uint64()%uint64(range_size))
|
||||
}
|
||||
|
||||
state.PushNumber(float64(result))
|
||||
return 1
|
||||
}
|
||||
|
||||
// cryptoRandom implements math.random functionality
|
||||
func cryptoRandom(state *luajit.State) int {
|
||||
numArgs := state.GetTop()
|
||||
|
||||
// Check if secure
|
||||
secure := false
|
||||
|
||||
// math.random() - return [0,1)
|
||||
if numArgs == 0 {
|
||||
if secure {
|
||||
bytes := make([]byte, 8)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return state.PushError("random: error: %v", err)
|
||||
}
|
||||
val := binary.BigEndian.Uint64(bytes)
|
||||
state.PushNumber(float64(val) / float64(math.MaxUint64))
|
||||
} else {
|
||||
stateRngsMu.Lock()
|
||||
stateRng, ok := stateRngs[state]
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return state.PushError("random: RNG not initialized")
|
||||
}
|
||||
|
||||
state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64))
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// math.random(n) - return integer [1,n]
|
||||
if numArgs == 1 && state.IsNumber(1) {
|
||||
n := int64(state.ToNumber(1))
|
||||
if n < 1 {
|
||||
return state.PushError("random: upper bound must be >= 1")
|
||||
}
|
||||
|
||||
state.PushNumber(1) // min
|
||||
state.PushNumber(float64(n)) // max
|
||||
state.PushBoolean(secure) // secure flag
|
||||
return cryptoRandomInt(state)
|
||||
}
|
||||
|
||||
// math.random(m, n) - return integer [m,n]
|
||||
if numArgs >= 2 && state.IsNumber(1) && state.IsNumber(2) {
|
||||
state.PushBoolean(secure) // secure flag
|
||||
return cryptoRandomInt(state)
|
||||
}
|
||||
|
||||
return state.PushError("random: invalid arguments")
|
||||
}
|
||||
|
||||
// cryptoRandomSeed sets seed for non-secure RNG
|
||||
func cryptoRandomSeed(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("randomseed: %v", err)
|
||||
}
|
||||
|
||||
seed, err := state.SafeToNumber(1)
|
||||
if err != nil {
|
||||
return state.PushError("randomseed: seed must be number")
|
||||
}
|
||||
|
||||
seedVal := uint64(seed)
|
||||
stateRngsMu.Lock()
|
||||
stateRngs[state] = mrand.NewPCG(seedVal, seedVal>>32)
|
||||
stateRngsMu.Unlock()
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// OverrideLuaRandom replaces Lua's math.random with Go implementation
|
||||
func OverrideLuaRandom(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("go_math_random", cryptoRandom); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := state.RegisterGoFunction("go_math_randomseed", cryptoRandomSeed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Replace original functions
|
||||
return state.DoString(`
|
||||
-- Save original functions
|
||||
_G._original_math_random = math.random
|
||||
_G._original_math_randomseed = math.randomseed
|
||||
|
||||
-- Replace with Go implementations
|
||||
math.random = go_math_random
|
||||
math.randomseed = go_math_randomseed
|
||||
|
||||
-- Clean up global namespace
|
||||
go_math_random = nil
|
||||
go_math_randomseed = nil
|
||||
`)
|
||||
}
|
||||
|
||||
// generateToken creates a cryptographically secure random token
|
||||
func generateToken(state *luajit.State) int {
|
||||
// Get the length from the Lua arguments (default to 32)
|
||||
length := 32
|
||||
if state.GetTop() >= 1 {
|
||||
if lengthVal, err := state.SafeToNumber(1); err == nil {
|
||||
length = int(lengthVal)
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce minimum length for security
|
||||
if length < 16 {
|
||||
length = 16
|
||||
}
|
||||
|
||||
// Generate secure random bytes
|
||||
tokenBytes := make([]byte, length)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
logger.Errorf("Failed to generate secure token: %v", err)
|
||||
state.PushString("")
|
||||
return 1 // Return empty string on error
|
||||
}
|
||||
|
||||
// Encode as base64
|
||||
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
// Trim to requested length (base64 might be longer)
|
||||
if len(token) > length {
|
||||
token = token[:length]
|
||||
}
|
||||
|
||||
// Push the token to the Lua stack
|
||||
state.PushString(token)
|
||||
return 1 // One return value
|
||||
}
|
@ -1,333 +0,0 @@
|
||||
package lualibs
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"Moonshark/logger"
|
||||
|
||||
"git.sharkk.net/Go/Color"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// EnvManager handles loading, storing, and saving environment variables
|
||||
type EnvManager struct {
|
||||
envPath string // Path to .env file
|
||||
vars map[string]any // Environment variables in memory
|
||||
mu sync.RWMutex // Thread-safe access
|
||||
}
|
||||
|
||||
// Global environment manager instance
|
||||
var globalEnvManager *EnvManager
|
||||
|
||||
// InitEnv initializes the environment manager with the given data directory
|
||||
func InitEnv(dataDir string) error {
|
||||
if dataDir == "" {
|
||||
return fmt.Errorf("data directory cannot be empty")
|
||||
}
|
||||
|
||||
// Create data directory if it doesn't exist
|
||||
if err := os.MkdirAll(dataDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create data directory: %w", err)
|
||||
}
|
||||
|
||||
envPath := filepath.Join(dataDir, ".env")
|
||||
|
||||
globalEnvManager = &EnvManager{
|
||||
envPath: envPath,
|
||||
vars: make(map[string]any),
|
||||
}
|
||||
|
||||
// Load existing .env file if it exists
|
||||
if err := globalEnvManager.load(); err != nil {
|
||||
logger.Warnf("Failed to load .env file: %v", err)
|
||||
}
|
||||
|
||||
count := len(globalEnvManager.vars)
|
||||
if count > 0 {
|
||||
logger.Infof("Environment loaded: %s vars from %s",
|
||||
color.Yellow(fmt.Sprintf("%d", count)),
|
||||
color.Yellow(envPath))
|
||||
} else {
|
||||
logger.Infof("Environment initialized: %s", color.Yellow(envPath))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGlobalEnvManager returns the global environment manager instance
|
||||
func GetGlobalEnvManager() *EnvManager {
|
||||
return globalEnvManager
|
||||
}
|
||||
|
||||
// parseValue attempts to parse a string value into the appropriate type
|
||||
func parseValue(value string) any {
|
||||
// Try boolean first
|
||||
if value == "true" {
|
||||
return true
|
||||
}
|
||||
if value == "false" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Try number
|
||||
if num, err := strconv.ParseFloat(value, 64); err == nil {
|
||||
// Check if it's actually an integer
|
||||
if num == float64(int64(num)) {
|
||||
return int64(num)
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
// Default to string
|
||||
return value
|
||||
}
|
||||
|
||||
// load reads the .env file and populates the vars map
|
||||
func (e *EnvManager) load() error {
|
||||
file, err := os.Open(e.envPath)
|
||||
if os.IsNotExist(err) {
|
||||
// File doesn't exist, start with empty env
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open .env file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse key=value
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
logger.Warnf("Invalid .env line %d: %s", lineNum, line)
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
// Remove quotes if present
|
||||
if len(value) >= 2 {
|
||||
if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) ||
|
||||
(strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) {
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
}
|
||||
|
||||
e.vars[key] = parseValue(value)
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// Save writes the current environment variables to the .env file
|
||||
func (e *EnvManager) Save() error {
|
||||
if e == nil {
|
||||
return nil // No env manager initialized
|
||||
}
|
||||
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
file, err := os.Create(e.envPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create .env file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Sort keys for consistent output
|
||||
keys := make([]string, 0, len(e.vars))
|
||||
for key := range e.vars {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
// Write header comment
|
||||
fmt.Fprintln(file, "# env variables - generated automatically - you can edit this file")
|
||||
fmt.Fprintln(file)
|
||||
|
||||
// Write each variable
|
||||
for _, key := range keys {
|
||||
value := e.vars[key]
|
||||
|
||||
// Convert value to string
|
||||
var strValue string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
strValue = v
|
||||
case bool:
|
||||
strValue = strconv.FormatBool(v)
|
||||
case int64:
|
||||
strValue = strconv.FormatInt(v, 10)
|
||||
case float64:
|
||||
strValue = strconv.FormatFloat(v, 'g', -1, 64)
|
||||
case nil:
|
||||
continue // Skip nil values
|
||||
default:
|
||||
strValue = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
// Quote values that contain spaces or special characters
|
||||
if strings.ContainsAny(strValue, " \t\n\r\"'\\") {
|
||||
strValue = fmt.Sprintf("\"%s\"", strings.ReplaceAll(strValue, "\"", "\\\""))
|
||||
}
|
||||
|
||||
fmt.Fprintf(file, "%s=%s\n", key, strValue)
|
||||
}
|
||||
|
||||
logger.Debugf("Environment saved: %d vars to %s", len(e.vars), e.envPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves an environment variable
|
||||
func (e *EnvManager) Get(key string) (any, bool) {
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
value, exists := e.vars[key]
|
||||
return value, exists
|
||||
}
|
||||
|
||||
// Set stores an environment variable
|
||||
func (e *EnvManager) Set(key string, value any) {
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.vars[key] = value
|
||||
}
|
||||
|
||||
// GetAll returns a copy of all environment variables
|
||||
func (e *EnvManager) GetAll() map[string]any {
|
||||
if e == nil {
|
||||
return make(map[string]any)
|
||||
}
|
||||
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
result := make(map[string]any, len(e.vars))
|
||||
for k, v := range e.vars {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// CleanupEnv saves the environment and cleans up resources
|
||||
func CleanupEnv() error {
|
||||
if globalEnvManager != nil {
|
||||
return globalEnvManager.Save()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// envGet Lua function to get an environment variable
|
||||
func envGet(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
key, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
if value, exists := globalEnvManager.Get(key); exists {
|
||||
if err := state.PushValue(value); err != nil {
|
||||
state.PushNil()
|
||||
}
|
||||
} else {
|
||||
state.PushNil()
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// envSet Lua function to set an environment variable
|
||||
func envSet(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(2); err != nil {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
key, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
// Handle different value types from Lua
|
||||
var value any
|
||||
switch state.GetType(2) {
|
||||
case luajit.TypeBoolean:
|
||||
value = state.ToBoolean(2)
|
||||
case luajit.TypeNumber:
|
||||
value = state.ToNumber(2)
|
||||
case luajit.TypeString:
|
||||
value = state.ToString(2)
|
||||
default:
|
||||
// Try to convert to string as fallback
|
||||
if str, err := state.SafeToString(2); err == nil {
|
||||
value = str
|
||||
} else {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
globalEnvManager.Set(key, value)
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// envGetAll Lua function to get all environment variables
|
||||
func envGetAll(state *luajit.State) int {
|
||||
vars := globalEnvManager.GetAll()
|
||||
if err := state.PushValue(vars); err != nil {
|
||||
state.PushNil()
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// RegisterEnvFunctions registers environment functions with the Lua state
|
||||
func RegisterEnvFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__env_get", envGet); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__env_set", envSet); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__env_get_all", envGetAll); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,572 +0,0 @@
|
||||
package lualibs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"Moonshark/logger"
|
||||
|
||||
"git.sharkk.net/Go/Color"
|
||||
|
||||
lru "git.sharkk.net/Go/LRU"
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/golang/snappy"
|
||||
)
|
||||
|
||||
// Global filesystem path (set during initialization)
|
||||
var fsBasePath string
|
||||
|
||||
// Global file cache with compressed data
|
||||
var fileCache *lru.LRUCache
|
||||
|
||||
// Cache entry info for statistics/debugging
|
||||
type cacheStats struct {
|
||||
hits int64
|
||||
misses int64
|
||||
}
|
||||
|
||||
var stats cacheStats
|
||||
|
||||
// InitFS initializes the filesystem with the given base path
|
||||
func InitFS(basePath string) error {
|
||||
if basePath == "" {
|
||||
return errors.New("filesystem base path cannot be empty")
|
||||
}
|
||||
|
||||
// Create the directory if it doesn't exist
|
||||
if err := os.MkdirAll(basePath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create filesystem directory: %w", err)
|
||||
}
|
||||
|
||||
// Store the absolute path
|
||||
absPath, err := filepath.Abs(basePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
|
||||
fsBasePath = absPath
|
||||
|
||||
// Initialize file cache with 2000 entries (reasonable for most use cases)
|
||||
fileCache = lru.NewLRUCache(2000)
|
||||
|
||||
logger.Infof("Filesystem is g2g! %s", color.Yellow(fsBasePath))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupFS performs any necessary cleanup
|
||||
func CleanupFS() {
|
||||
if fileCache != nil {
|
||||
fileCache.Clear()
|
||||
logger.Infof(
|
||||
"File cache cleared - %s hits, %s misses",
|
||||
color.Yellow(fmt.Sprintf("%d", stats.hits)),
|
||||
color.Red(fmt.Sprintf("%d", stats.misses)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ResolvePath resolves a given path relative to the filesystem base
|
||||
// Returns the actual path and an error if the path tries to escape the sandbox
|
||||
func ResolvePath(path string) (string, error) {
|
||||
if fsBasePath == "" {
|
||||
return "", errors.New("filesystem not initialized")
|
||||
}
|
||||
|
||||
// Clean the path to remove any .. or . components
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Replace backslashes with forward slashes for consistent handling
|
||||
cleanPath = strings.ReplaceAll(cleanPath, "\\", "/")
|
||||
|
||||
// Remove any leading / or drive letter to make it relative
|
||||
cleanPath = strings.TrimPrefix(cleanPath, "/")
|
||||
|
||||
// Remove drive letter on Windows (e.g. C:)
|
||||
if len(cleanPath) >= 2 && cleanPath[1] == ':' {
|
||||
cleanPath = cleanPath[2:]
|
||||
}
|
||||
|
||||
// Ensure the path doesn't contain .. to prevent escaping
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return "", errors.New("path cannot contain .. components")
|
||||
}
|
||||
|
||||
// Join with the base path
|
||||
fullPath := filepath.Join(fsBasePath, cleanPath)
|
||||
|
||||
// Verify the path is still within the base directory
|
||||
if !strings.HasPrefix(fullPath, fsBasePath) {
|
||||
return "", errors.New("path escapes the filesystem sandbox")
|
||||
}
|
||||
|
||||
return fullPath, nil
|
||||
}
|
||||
|
||||
// getCacheKey creates a cache key from path and modification time
|
||||
func getCacheKey(fullPath string, modTime time.Time) string {
|
||||
return fmt.Sprintf("%s:%d", fullPath, modTime.Unix())
|
||||
}
|
||||
|
||||
// fsReadFile reads a file and returns its contents
|
||||
func fsReadFile(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("fs.read_file: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.read_file: path must be string")
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
return state.PushError("fs.read_file: %v", err)
|
||||
}
|
||||
|
||||
// Get file info for cache key and validation
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
return state.PushError("fs.read_file: %v", err)
|
||||
}
|
||||
|
||||
// Create cache key with path and modification time
|
||||
cacheKey := getCacheKey(fullPath, info.ModTime())
|
||||
|
||||
// Try to get from cache first
|
||||
if fileCache != nil {
|
||||
if cachedData, exists := fileCache.Get(cacheKey); exists {
|
||||
if compressedData, ok := cachedData.([]byte); ok {
|
||||
// Decompress cached data
|
||||
data, err := snappy.Decode(nil, compressedData)
|
||||
if err == nil {
|
||||
stats.hits++
|
||||
state.PushString(string(data))
|
||||
return 1
|
||||
}
|
||||
// Cache corruption - continue to disk read
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss or error - read from disk
|
||||
stats.misses++
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
return state.PushError("fs.read_file: %v", err)
|
||||
}
|
||||
|
||||
// Compress and cache the data
|
||||
if fileCache != nil {
|
||||
compressedData := snappy.Encode(nil, data)
|
||||
fileCache.Put(cacheKey, compressedData)
|
||||
}
|
||||
|
||||
state.PushString(string(data))
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsWriteFile writes data to a file
|
||||
func fsWriteFile(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(2); err != nil {
|
||||
return state.PushError("fs.write_file: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.write_file: path must be string")
|
||||
}
|
||||
|
||||
content, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
return state.PushError("fs.write_file: content must be string")
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
return state.PushError("fs.write_file: %v", err)
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
dir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return state.PushError("fs.write_file: failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil {
|
||||
return state.PushError("fs.write_file: %v", err)
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsAppendFile appends data to a file
|
||||
func fsAppendFile(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(2); err != nil {
|
||||
return state.PushError("fs.append_file: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.append_file: path must be string")
|
||||
}
|
||||
|
||||
content, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
return state.PushError("fs.append_file: content must be string")
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
return state.PushError("fs.append_file: %v", err)
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
dir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return state.PushError("fs.append_file: failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return state.PushError("fs.append_file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err = file.Write([]byte(content)); err != nil {
|
||||
return state.PushError("fs.append_file: %v", err)
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsExists checks if a file or directory exists
|
||||
func fsExists(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("fs.exists: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.exists: path must be string")
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
return state.PushError("fs.exists: %v", err)
|
||||
}
|
||||
|
||||
_, err = os.Stat(fullPath)
|
||||
state.PushBoolean(err == nil)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsRemoveFile removes a file
|
||||
func fsRemoveFile(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("fs.remove_file: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.remove_file: path must be string")
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
return state.PushError("fs.remove_file: %v", err)
|
||||
}
|
||||
|
||||
// Check if it's a directory
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
return state.PushError("fs.remove_file: %v", err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return state.PushError("fs.remove_file: cannot remove directory, use remove_dir instead")
|
||||
}
|
||||
|
||||
if err := os.Remove(fullPath); err != nil {
|
||||
return state.PushError("fs.remove_file: %v", err)
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsGetInfo gets information about a file
|
||||
func fsGetInfo(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("fs.get_info: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.get_info: path must be string")
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
return state.PushError("fs.get_info: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
return state.PushError("fs.get_info: %v", err)
|
||||
}
|
||||
|
||||
fileInfo := map[string]any{
|
||||
"name": info.Name(),
|
||||
"size": info.Size(),
|
||||
"mode": int(info.Mode()),
|
||||
"mod_time": info.ModTime().Unix(),
|
||||
"is_dir": info.IsDir(),
|
||||
}
|
||||
|
||||
if err := state.PushValue(fileInfo); err != nil {
|
||||
return state.PushError("fs.get_info: %v", err)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsMakeDir creates a directory
|
||||
func fsMakeDir(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(1); err != nil {
|
||||
return state.PushError("fs.make_dir: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.make_dir: path must be string")
|
||||
}
|
||||
|
||||
perm := os.FileMode(0755)
|
||||
if state.GetTop() >= 2 {
|
||||
if permVal, err := state.SafeToNumber(2); err == nil {
|
||||
perm = os.FileMode(permVal)
|
||||
}
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
return state.PushError("fs.make_dir: %v", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(fullPath, perm); err != nil {
|
||||
return state.PushError("fs.make_dir: %v", err)
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsListDir lists the contents of a directory
|
||||
func fsListDir(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("fs.list_dir: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.list_dir: path must be string")
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
return state.PushError("fs.list_dir: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
return state.PushError("fs.list_dir: %v", err)
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return state.PushError("fs.list_dir: not a directory")
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(fullPath)
|
||||
if err != nil {
|
||||
return state.PushError("fs.list_dir: %v", err)
|
||||
}
|
||||
|
||||
// Create array of filenames
|
||||
filenames := make([]string, len(files))
|
||||
for i, file := range files {
|
||||
filenames[i] = file.Name()
|
||||
}
|
||||
|
||||
if err := state.PushValue(filenames); err != nil {
|
||||
return state.PushError("fs.list_dir: %v", err)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsRemoveDir removes a directory
|
||||
func fsRemoveDir(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(1); err != nil {
|
||||
return state.PushError("fs.remove_dir: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.remove_dir: path must be string")
|
||||
}
|
||||
|
||||
recursive := false
|
||||
if state.GetTop() >= 2 {
|
||||
recursive = state.ToBoolean(2)
|
||||
}
|
||||
|
||||
fullPath, err := ResolvePath(path)
|
||||
if err != nil {
|
||||
return state.PushError("fs.remove_dir: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
return state.PushError("fs.remove_dir: %v", err)
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return state.PushError("fs.remove_dir: not a directory")
|
||||
}
|
||||
|
||||
if recursive {
|
||||
err = os.RemoveAll(fullPath)
|
||||
} else {
|
||||
err = os.Remove(fullPath)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return state.PushError("fs.remove_dir: %v", err)
|
||||
}
|
||||
|
||||
state.PushBoolean(true)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsJoinPaths joins path components
|
||||
func fsJoinPaths(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(1); err != nil {
|
||||
return state.PushError("fs.join_paths: %v", err)
|
||||
}
|
||||
|
||||
components := make([]string, state.GetTop())
|
||||
for i := 1; i <= state.GetTop(); i++ {
|
||||
comp, err := state.SafeToString(i)
|
||||
if err != nil {
|
||||
return state.PushError("fs.join_paths: all arguments must be strings")
|
||||
}
|
||||
components[i-1] = comp
|
||||
}
|
||||
|
||||
result := filepath.Join(components...)
|
||||
result = strings.ReplaceAll(result, "\\", "/")
|
||||
|
||||
state.PushString(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsDirName returns the directory portion of a path
|
||||
func fsDirName(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("fs.dir_name: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.dir_name: path must be string")
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
dir = strings.ReplaceAll(dir, "\\", "/")
|
||||
|
||||
state.PushString(dir)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsBaseName returns the file name portion of a path
|
||||
func fsBaseName(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("fs.base_name: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.base_name: path must be string")
|
||||
}
|
||||
|
||||
base := filepath.Base(path)
|
||||
state.PushString(base)
|
||||
return 1
|
||||
}
|
||||
|
||||
// fsExtension returns the file extension
|
||||
func fsExtension(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("fs.extension: %v", err)
|
||||
}
|
||||
|
||||
path, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("fs.extension: path must be string")
|
||||
}
|
||||
|
||||
ext := filepath.Ext(path)
|
||||
state.PushString(ext)
|
||||
return 1
|
||||
}
|
||||
|
||||
// RegisterFSFunctions registers filesystem functions with the Lua state
|
||||
func RegisterFSFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__fs_read_file", fsReadFile); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_write_file", fsWriteFile); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_append_file", fsAppendFile); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_exists", fsExists); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_remove_file", fsRemoveFile); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_get_info", fsGetInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_make_dir", fsMakeDir); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_list_dir", fsListDir); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_remove_dir", fsRemoveDir); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_join_paths", fsJoinPaths); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_dir_name", fsDirName); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_base_name", fsBaseName); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__fs_extension", fsExtension); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,227 +0,0 @@
|
||||
package lualibs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/valyala/bytebufferpool"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Default HTTP client with sensible timeout
|
||||
var defaultFastClient = fasthttp.Client{
|
||||
MaxConnsPerHost: 1024,
|
||||
MaxIdleConnDuration: time.Minute,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
DisableHeaderNamesNormalizing: true,
|
||||
}
|
||||
|
||||
// HTTPClientConfig contains client settings
|
||||
type HTTPClientConfig struct {
|
||||
MaxTimeout time.Duration // Maximum timeout for requests (0 = no limit)
|
||||
DefaultTimeout time.Duration // Default request timeout
|
||||
MaxResponseSize int64 // Maximum response size in bytes (0 = no limit)
|
||||
AllowRemote bool // Whether to allow remote connections
|
||||
}
|
||||
|
||||
// DefaultHTTPClientConfig provides sensible defaults
|
||||
var DefaultHTTPClientConfig = HTTPClientConfig{
|
||||
MaxTimeout: 60 * time.Second,
|
||||
DefaultTimeout: 30 * time.Second,
|
||||
MaxResponseSize: 10 * 1024 * 1024, // 10MB
|
||||
AllowRemote: true,
|
||||
}
|
||||
|
||||
// RegisterHttpFunctions registers HTTP functions with the Lua state
|
||||
func RegisterHttpFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// httpRequest makes an HTTP request and returns the result to Lua
|
||||
func httpRequest(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(2); err != nil {
|
||||
return state.PushError("http.client.request: %v", err)
|
||||
}
|
||||
|
||||
// Get method and URL
|
||||
method, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("http.client.request: method must be string")
|
||||
}
|
||||
method = strings.ToUpper(method)
|
||||
|
||||
urlStr, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
return state.PushError("http.client.request: url must be string")
|
||||
}
|
||||
|
||||
// Parse URL to check if it's valid
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return state.PushError("Invalid URL: %v", err)
|
||||
}
|
||||
|
||||
// Get client configuration
|
||||
config := DefaultHTTPClientConfig
|
||||
|
||||
// Check if remote connections are allowed
|
||||
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
|
||||
return state.PushError("Remote connections are not allowed")
|
||||
}
|
||||
|
||||
// Use bytebufferpool for request and response
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set up request
|
||||
req.Header.SetMethod(method)
|
||||
req.SetRequestURI(urlStr)
|
||||
req.Header.Set("User-Agent", "Moonshark/1.0")
|
||||
|
||||
// Get body (optional)
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
if state.IsString(3) {
|
||||
// String body
|
||||
bodyStr, _ := state.SafeToString(3)
|
||||
req.SetBodyString(bodyStr)
|
||||
} else if state.IsTable(3) {
|
||||
// Table body - convert to JSON
|
||||
luaTable, err := state.ToTable(3)
|
||||
if err != nil {
|
||||
return state.PushError("Failed to parse body table: %v", err)
|
||||
}
|
||||
|
||||
// Use bytebufferpool for JSON serialization
|
||||
buf := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buf)
|
||||
|
||||
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
|
||||
return state.PushError("Failed to convert body to JSON: %v", err)
|
||||
}
|
||||
|
||||
req.SetBody(buf.Bytes())
|
||||
req.Header.SetContentType("application/json")
|
||||
} else {
|
||||
return state.PushError("Body must be a string or table")
|
||||
}
|
||||
}
|
||||
|
||||
// Process options (headers, timeout, etc.)
|
||||
timeout := config.DefaultTimeout
|
||||
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
|
||||
if headers, ok := state.GetFieldTable(4, "headers"); ok {
|
||||
if headerMap, ok := headers.(map[string]any); ok {
|
||||
for name, value := range headerMap {
|
||||
if valueStr, ok := value.(string); ok {
|
||||
req.Header.Set(name, valueStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get timeout
|
||||
if timeoutVal := state.GetFieldNumber(4, "timeout", 0); timeoutVal > 0 {
|
||||
requestTimeout := time.Duration(timeoutVal) * time.Second
|
||||
|
||||
// Apply max timeout if configured
|
||||
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
|
||||
timeout = config.MaxTimeout
|
||||
} else {
|
||||
timeout = requestTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// Process query parameters
|
||||
if query, ok := state.GetFieldTable(4, "query"); ok {
|
||||
args := req.URI().QueryArgs()
|
||||
|
||||
if queryMap, ok := query.(map[string]any); ok {
|
||||
for name, value := range queryMap {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
args.Add(name, v)
|
||||
case int:
|
||||
args.Add(name, fmt.Sprintf("%d", v))
|
||||
case float64:
|
||||
args.Add(name, strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", v), "0"), "."))
|
||||
case bool:
|
||||
if v {
|
||||
args.Add(name, "true")
|
||||
} else {
|
||||
args.Add(name, "false")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
_, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Execute request
|
||||
err = defaultFastClient.DoTimeout(req, resp, timeout)
|
||||
if err != nil {
|
||||
errStr := "Request failed: " + err.Error()
|
||||
if errors.Is(err, fasthttp.ErrTimeout) {
|
||||
errStr = "Request timed out after " + timeout.String()
|
||||
}
|
||||
return state.PushError("%s", errStr)
|
||||
}
|
||||
|
||||
// Create response using TableBuilder
|
||||
builder := state.NewTableBuilder()
|
||||
|
||||
// Set status code and text
|
||||
builder.SetNumber("status", float64(resp.StatusCode()))
|
||||
builder.SetString("status_text", fasthttp.StatusMessage(resp.StatusCode()))
|
||||
|
||||
// Set body
|
||||
var respBody []byte
|
||||
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
|
||||
// Make a limited copy
|
||||
respBody = make([]byte, config.MaxResponseSize)
|
||||
copy(respBody, resp.Body())
|
||||
} else {
|
||||
respBody = resp.Body()
|
||||
}
|
||||
|
||||
builder.SetString("body", string(respBody))
|
||||
|
||||
// Parse body as JSON if content type is application/json
|
||||
contentType := string(resp.Header.ContentType())
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
var jsonData any
|
||||
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
||||
builder.SetTable("json", jsonData)
|
||||
}
|
||||
}
|
||||
|
||||
// Set headers
|
||||
headers := make(map[string]string)
|
||||
resp.Header.VisitAll(func(key, value []byte) {
|
||||
headers[string(key)] = string(value)
|
||||
})
|
||||
builder.SetTable("headers", headers)
|
||||
|
||||
// Create ok field (true if status code is 2xx)
|
||||
builder.SetBool("ok", resp.StatusCode() >= 200 && resp.StatusCode() < 300)
|
||||
|
||||
builder.Build()
|
||||
return 1
|
||||
}
|
@ -1,97 +0,0 @@
|
||||
package lualibs
|
||||
|
||||
import (
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/alexedwards/argon2id"
|
||||
)
|
||||
|
||||
// RegisterPasswordFunctions registers password-related functions in the Lua state
|
||||
func RegisterPasswordFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__password_hash", passwordHash); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__password_verify", passwordVerify); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// passwordHash implements the Argon2id password hashing using alexedwards/argon2id
|
||||
func passwordHash(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(1); err != nil {
|
||||
return state.PushError("password_hash: %v", err)
|
||||
}
|
||||
|
||||
password, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("password_hash: password must be string")
|
||||
}
|
||||
|
||||
params := &argon2id.Params{
|
||||
Memory: 128 * 1024,
|
||||
Iterations: 4,
|
||||
Parallelism: 4,
|
||||
SaltLength: 16,
|
||||
KeyLength: 32,
|
||||
}
|
||||
|
||||
if state.GetTop() >= 2 && state.IsTable(2) {
|
||||
// Use new field getters with validation
|
||||
if memory := state.GetFieldNumber(2, "memory", 0); memory > 0 {
|
||||
params.Memory = max(uint32(memory), 8*1024)
|
||||
}
|
||||
|
||||
if iterations := state.GetFieldNumber(2, "iterations", 0); iterations > 0 {
|
||||
params.Iterations = max(uint32(iterations), 1)
|
||||
}
|
||||
|
||||
if parallelism := state.GetFieldNumber(2, "parallelism", 0); parallelism > 0 {
|
||||
params.Parallelism = max(uint8(parallelism), 1)
|
||||
}
|
||||
|
||||
if saltLength := state.GetFieldNumber(2, "salt_length", 0); saltLength > 0 {
|
||||
params.SaltLength = max(uint32(saltLength), 8)
|
||||
}
|
||||
|
||||
if keyLength := state.GetFieldNumber(2, "key_length", 0); keyLength > 0 {
|
||||
params.KeyLength = max(uint32(keyLength), 16)
|
||||
}
|
||||
}
|
||||
|
||||
hash, err := argon2id.CreateHash(password, params)
|
||||
if err != nil {
|
||||
return state.PushError("password_hash: %v", err)
|
||||
}
|
||||
|
||||
state.PushString(hash)
|
||||
return 1
|
||||
}
|
||||
|
||||
// passwordVerify verifies a password against a hash
|
||||
func passwordVerify(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(2); err != nil {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
password, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
hash, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
match, err := argon2id.ComparePasswordAndHash(password, hash)
|
||||
if err != nil {
|
||||
state.PushBoolean(false)
|
||||
return 1
|
||||
}
|
||||
|
||||
state.PushBoolean(match)
|
||||
return 1
|
||||
}
|
@ -1,192 +0,0 @@
|
||||
package lualibs
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"html"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// RegisterUtilFunctions registers utility functions with the Lua state
|
||||
func RegisterUtilFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__json_marshal", jsonMarshal); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := state.RegisterGoFunction("__json_unmarshal", jsonUnmarshal); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// HTML special chars
|
||||
if err := state.RegisterGoFunction("__html_special_chars", htmlSpecialChars); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// HTML entities
|
||||
if err := state.RegisterGoFunction("__html_entities", htmlEntities); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Base64 encode
|
||||
if err := state.RegisterGoFunction("__base64_encode", base64Encode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Base64 decode
|
||||
if err := state.RegisterGoFunction("__base64_decode", base64Decode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// htmlSpecialChars converts special characters to HTML entities
|
||||
func htmlSpecialChars(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
input, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
result := html.EscapeString(input)
|
||||
state.PushString(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
// htmlEntities is a more comprehensive version of htmlSpecialChars
|
||||
func htmlEntities(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
input, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
// First use HTML escape for standard entities
|
||||
result := html.EscapeString(input)
|
||||
|
||||
// Additional entities beyond what html.EscapeString handles
|
||||
replacements := map[string]string{
|
||||
"©": "©",
|
||||
"®": "®",
|
||||
"™": "™",
|
||||
"€": "€",
|
||||
"£": "£",
|
||||
"¥": "¥",
|
||||
"—": "—",
|
||||
"–": "–",
|
||||
"…": "…",
|
||||
"•": "•",
|
||||
"°": "°",
|
||||
"±": "±",
|
||||
"¼": "¼",
|
||||
"½": "½",
|
||||
"¾": "¾",
|
||||
}
|
||||
|
||||
for char, entity := range replacements {
|
||||
result = strings.ReplaceAll(result, char, entity)
|
||||
}
|
||||
|
||||
state.PushString(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
// base64Encode encodes a string to base64
|
||||
func base64Encode(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
input, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
result := base64.StdEncoding.EncodeToString([]byte(input))
|
||||
state.PushString(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
// base64Decode decodes a base64 string
|
||||
func base64Decode(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
input, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
result, err := base64.StdEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
state.PushNil()
|
||||
return 1
|
||||
}
|
||||
|
||||
state.PushString(string(result))
|
||||
return 1
|
||||
}
|
||||
|
||||
// jsonMarshal converts a Lua value to a JSON string with validation
|
||||
func jsonMarshal(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("json marshal: %v", err)
|
||||
}
|
||||
|
||||
value, err := state.ToTable(1)
|
||||
if err != nil {
|
||||
// Try as generic value if not a table
|
||||
value, err = state.ToValue(1)
|
||||
if err != nil {
|
||||
return state.PushError("json marshal error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return state.PushError("json marshal error: %v", err)
|
||||
}
|
||||
|
||||
state.PushString(string(bytes))
|
||||
return 1
|
||||
}
|
||||
|
||||
// jsonUnmarshal converts a JSON string to a Lua value with validation
|
||||
func jsonUnmarshal(state *luajit.State) int {
|
||||
if err := state.CheckExactArgs(1); err != nil {
|
||||
return state.PushError("json unmarshal: %v", err)
|
||||
}
|
||||
|
||||
jsonStr, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("json unmarshal: expected string, got %s", state.GetType(1))
|
||||
}
|
||||
|
||||
var value any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &value); err != nil {
|
||||
return state.PushError("json unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if err := state.PushValue(value); err != nil {
|
||||
return state.PushError("json unmarshal error: %v", err)
|
||||
}
|
||||
return 1
|
||||
}
|
@ -1,292 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"Moonshark/logger"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
type ModuleConfig struct {
|
||||
ScriptDir string
|
||||
LibDirs []string
|
||||
}
|
||||
|
||||
type ModuleLoader struct {
|
||||
config *ModuleConfig
|
||||
pathCache map[string]string // For reverse lookups (path -> module name)
|
||||
debug bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewModuleLoader(config *ModuleConfig) *ModuleLoader {
|
||||
if config == nil {
|
||||
config = &ModuleConfig{}
|
||||
}
|
||||
|
||||
return &ModuleLoader{
|
||||
config: config,
|
||||
pathCache: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) EnableDebug() {
|
||||
l.debug = true
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) SetScriptDir(dir string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.config.ScriptDir = dir
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) debugLog(format string, args ...any) {
|
||||
if l.debug {
|
||||
logger.Debugf("ModuleLoader "+format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) SetupRequire(state *luajit.State) error {
|
||||
// Set package.path
|
||||
paths := l.getSearchPaths()
|
||||
pathStr := strings.Join(paths, ";")
|
||||
|
||||
return state.DoString(`package.path = "` + escapeLuaString(pathStr) + `"`)
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) getSearchPaths() []string {
|
||||
var paths []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// Script directory first
|
||||
if l.config.ScriptDir != "" {
|
||||
if absPath, err := filepath.Abs(l.config.ScriptDir); err == nil && !seen[absPath] {
|
||||
paths = append(paths, filepath.Join(absPath, "?.lua"))
|
||||
seen[absPath] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Library directories
|
||||
for _, dir := range l.config.LibDirs {
|
||||
if dir == "" {
|
||||
continue
|
||||
}
|
||||
if absPath, err := filepath.Abs(dir); err == nil && !seen[absPath] {
|
||||
paths = append(paths, filepath.Join(absPath, "?.lua"))
|
||||
seen[absPath] = true
|
||||
}
|
||||
}
|
||||
|
||||
return paths
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) PreloadModules(state *luajit.State) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
// Reset caches
|
||||
l.pathCache = make(map[string]string)
|
||||
|
||||
// Clear non-core modules
|
||||
err := state.DoString(`
|
||||
local core = {string=1, table=1, math=1, os=1, package=1, io=1, coroutine=1, debug=1, _G=1}
|
||||
for name in pairs(package.loaded) do
|
||||
if not core[name] then package.loaded[name] = nil end
|
||||
end
|
||||
package.preload = {}
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Scan and preload modules
|
||||
for _, dir := range l.config.LibDirs {
|
||||
if err := l.scanDirectory(state, dir); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Install simplified require
|
||||
return state.DoString(`
|
||||
function __setup_require(env)
|
||||
env.require = function(modname)
|
||||
if package.loaded[modname] then
|
||||
return package.loaded[modname]
|
||||
end
|
||||
|
||||
local loader = package.preload[modname]
|
||||
if loader then
|
||||
setfenv(loader, env)
|
||||
local result = loader() or true
|
||||
package.loaded[modname] = result
|
||||
return result
|
||||
end
|
||||
|
||||
-- Standard path search
|
||||
for path in package.path:gmatch("[^;]+") do
|
||||
local file = path:gsub("?", modname:gsub("%.", "/"))
|
||||
local chunk = loadfile(file)
|
||||
if chunk then
|
||||
setfenv(chunk, env)
|
||||
local result = chunk() or true
|
||||
package.loaded[modname] = result
|
||||
return result
|
||||
end
|
||||
end
|
||||
|
||||
error("module '" .. modname .. "' not found", 2)
|
||||
end
|
||||
return env
|
||||
end
|
||||
`)
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) scanDirectory(state *luajit.State, dir string) error {
|
||||
if dir == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.debugLog("Scanning directory: %s", absDir)
|
||||
|
||||
return filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil || info.IsDir() || !strings.HasSuffix(path, ".lua") {
|
||||
return nil
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(absDir, path)
|
||||
if err != nil || strings.HasPrefix(relPath, "..") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert to module name
|
||||
modName := strings.TrimSuffix(relPath, ".lua")
|
||||
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
||||
|
||||
l.debugLog("Found module: %s at %s", modName, path)
|
||||
l.pathCache[modName] = path
|
||||
|
||||
// Load and compile module
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
l.debugLog("Failed to read %s: %v", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := state.LoadString(string(content)); err != nil {
|
||||
l.debugLog("Failed to compile %s: %v", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store in package.preload
|
||||
state.GetGlobal("package")
|
||||
state.GetField(-1, "preload")
|
||||
state.PushString(modName)
|
||||
state.PushCopy(-4) // Copy compiled function
|
||||
state.SetTable(-3)
|
||||
state.Pop(2) // Pop package and preload
|
||||
state.Pop(1) // Pop function
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) GetModuleByPath(path string) (string, bool) {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
absPath = filepath.Clean(path)
|
||||
}
|
||||
|
||||
// Direct lookup
|
||||
for modName, modPath := range l.pathCache {
|
||||
if modPath == absPath {
|
||||
return modName, true
|
||||
}
|
||||
}
|
||||
|
||||
// Construct from lib dirs
|
||||
for _, dir := range l.config.LibDirs {
|
||||
absDir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(absDir, absPath)
|
||||
if err != nil || strings.HasPrefix(relPath, "..") || !strings.HasSuffix(relPath, ".lua") {
|
||||
continue
|
||||
}
|
||||
|
||||
modName := strings.TrimSuffix(relPath, ".lua")
|
||||
modName = strings.ReplaceAll(modName, string(filepath.Separator), ".")
|
||||
return modName, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) RefreshModule(state *luajit.State, moduleName string) error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
path, exists := l.pathCache[moduleName]
|
||||
if !exists {
|
||||
return fmt.Errorf("module %s not found", moduleName)
|
||||
}
|
||||
|
||||
l.debugLog("Refreshing module: %s", moduleName)
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read module: %w", err)
|
||||
}
|
||||
|
||||
// Compile new version
|
||||
if err := state.LoadString(string(content)); err != nil {
|
||||
return fmt.Errorf("failed to compile module: %w", err)
|
||||
}
|
||||
|
||||
// Update package.preload
|
||||
state.GetGlobal("package")
|
||||
state.GetField(-1, "preload")
|
||||
state.PushString(moduleName)
|
||||
state.PushCopy(-4) // Copy function
|
||||
state.SetTable(-3)
|
||||
state.Pop(2) // Pop package and preload
|
||||
state.Pop(1) // Pop function
|
||||
|
||||
// Clear from loaded
|
||||
state.DoString(`package.loaded["` + escapeLuaString(moduleName) + `"] = nil`)
|
||||
|
||||
l.debugLog("Successfully refreshed: %s", moduleName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *ModuleLoader) RefreshModuleByPath(state *luajit.State, filePath string) error {
|
||||
moduleName, exists := l.GetModuleByPath(filePath)
|
||||
if !exists {
|
||||
return fmt.Errorf("no module found for path: %s", filePath)
|
||||
}
|
||||
return l.RefreshModule(state, moduleName)
|
||||
}
|
||||
|
||||
func escapeLuaString(s string) string {
|
||||
return strings.NewReplacer(
|
||||
`\`, `\\`,
|
||||
`"`, `\"`,
|
||||
"\n", `\n`,
|
||||
"\r", `\r`,
|
||||
"\t", `\t`,
|
||||
).Replace(s)
|
||||
}
|
@ -1,56 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Response represents a unified response from script execution
|
||||
type Response struct {
|
||||
// Basic properties
|
||||
Body any // Body content (any type)
|
||||
Metadata map[string]any // Additional metadata
|
||||
|
||||
// HTTP specific properties
|
||||
Status int // HTTP status code
|
||||
Headers map[string]string // HTTP headers
|
||||
Cookies []*fasthttp.Cookie // HTTP cookies
|
||||
|
||||
// Session information
|
||||
SessionData map[string]any
|
||||
}
|
||||
|
||||
// Response pool to reduce allocations
|
||||
var responsePool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Response{
|
||||
Status: 200,
|
||||
Headers: make(map[string]string, 8),
|
||||
Metadata: make(map[string]any, 8),
|
||||
Cookies: make([]*fasthttp.Cookie, 0, 4),
|
||||
SessionData: make(map[string]any, 8),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// NewResponse creates a new response object from the pool
|
||||
func NewResponse() *Response {
|
||||
return responsePool.Get().(*Response)
|
||||
}
|
||||
|
||||
// Release returns a response to the pool after cleaning it
|
||||
func ReleaseResponse(resp *Response) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp.Body = nil
|
||||
resp.Status = 200
|
||||
resp.Headers = make(map[string]string, 8)
|
||||
resp.Metadata = make(map[string]any, 8)
|
||||
resp.Cookies = resp.Cookies[:0]
|
||||
resp.SessionData = make(map[string]any, 8)
|
||||
|
||||
responsePool.Put(resp)
|
||||
}
|
335
runner/runner.go
335
runner/runner.go
@ -1,335 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"Moonshark/config"
|
||||
"Moonshark/logger"
|
||||
"Moonshark/runner/lualibs"
|
||||
"Moonshark/runner/sqlite"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
var emptyMap = make(map[string]any)
|
||||
|
||||
var (
|
||||
ErrRunnerClosed = errors.New("lua runner is closed")
|
||||
ErrTimeout = errors.New("operation timed out")
|
||||
ErrStateNotReady = errors.New("lua state not ready")
|
||||
)
|
||||
|
||||
type State struct {
|
||||
L *luajit.State
|
||||
sandbox *Sandbox
|
||||
index int
|
||||
inUse atomic.Bool
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
states []*State
|
||||
statePool chan int
|
||||
poolSize int
|
||||
moduleLoader *ModuleLoader
|
||||
isRunning atomic.Bool
|
||||
mu sync.RWMutex
|
||||
scriptDir string
|
||||
|
||||
// Pre-allocated pools for HTTP processing
|
||||
ctxPool sync.Pool
|
||||
paramsPool sync.Pool
|
||||
}
|
||||
|
||||
func NewRunner(cfg *config.Config, poolSize int) (*Runner, error) {
|
||||
if poolSize <= 0 && cfg.Runner.PoolSize <= 0 {
|
||||
poolSize = runtime.GOMAXPROCS(0)
|
||||
}
|
||||
|
||||
moduleConfig := &ModuleConfig{
|
||||
LibDirs: cfg.Dirs.Libs,
|
||||
}
|
||||
|
||||
r := &Runner{
|
||||
poolSize: poolSize,
|
||||
moduleLoader: NewModuleLoader(moduleConfig),
|
||||
ctxPool: sync.Pool{
|
||||
New: func() any { return make(map[string]any, 8) },
|
||||
},
|
||||
paramsPool: sync.Pool{
|
||||
New: func() any { return make(map[string]any, 4) },
|
||||
},
|
||||
}
|
||||
|
||||
sqlite.InitSQLite(cfg.Dirs.Data)
|
||||
sqlite.SetSQLitePoolSize(poolSize)
|
||||
lualibs.InitFS(cfg.Dirs.FS)
|
||||
lualibs.InitEnv(cfg.Dirs.Data)
|
||||
|
||||
r.states = make([]*State, poolSize)
|
||||
r.statePool = make(chan int, poolSize)
|
||||
|
||||
if err := r.initStates(); err != nil {
|
||||
sqlite.CleanupSQLite()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.isRunning.Store(true)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Runner) Execute(bytecode []byte, ctx ExecutionContext) (*Response, error) {
|
||||
if !r.isRunning.Load() {
|
||||
return nil, ErrRunnerClosed
|
||||
}
|
||||
|
||||
var stateIndex int
|
||||
select {
|
||||
case stateIndex = <-r.statePool:
|
||||
case <-time.After(time.Second):
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
state := r.states[stateIndex]
|
||||
if state == nil {
|
||||
r.statePool <- stateIndex
|
||||
return nil, ErrStateNotReady
|
||||
}
|
||||
|
||||
state.inUse.Store(true)
|
||||
defer func() {
|
||||
state.inUse.Store(false)
|
||||
if r.isRunning.Load() {
|
||||
select {
|
||||
case r.statePool <- stateIndex:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return state.sandbox.Execute(state.L, bytecode, ctx, state.index)
|
||||
}
|
||||
|
||||
func (r *Runner) initStates() error {
|
||||
logger.Infof("[LuaRunner] Creating %d states...", r.poolSize)
|
||||
|
||||
for i := range r.poolSize {
|
||||
state, err := r.createState(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.states[i] = state
|
||||
r.statePool <- i
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) createState(index int) (*State, error) {
|
||||
L := luajit.New(true)
|
||||
if L == nil {
|
||||
return nil, errors.New("failed to create Lua state")
|
||||
}
|
||||
|
||||
sb := NewSandbox()
|
||||
if err := sb.Setup(L, index, index == 0); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := r.moduleLoader.SetupRequire(L); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := r.moduleLoader.PreloadModules(L); err != nil {
|
||||
L.Cleanup()
|
||||
L.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &State{L: L, sandbox: sb, index: index}, nil
|
||||
}
|
||||
|
||||
func (r *Runner) Close() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if !r.isRunning.Load() {
|
||||
return ErrRunnerClosed
|
||||
}
|
||||
r.isRunning.Store(false)
|
||||
|
||||
// Drain pool
|
||||
for {
|
||||
select {
|
||||
case <-r.statePool:
|
||||
default:
|
||||
goto cleanup
|
||||
}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
// Wait for states to finish
|
||||
timeout := time.Now().Add(10 * time.Second)
|
||||
for time.Now().Before(timeout) {
|
||||
allIdle := true
|
||||
for _, state := range r.states {
|
||||
if state != nil && state.inUse.Load() {
|
||||
allIdle = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allIdle {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Close states
|
||||
for i, state := range r.states {
|
||||
if state != nil {
|
||||
state.L.Cleanup()
|
||||
state.L.Close()
|
||||
r.states[i] = nil
|
||||
}
|
||||
}
|
||||
|
||||
lualibs.CleanupFS()
|
||||
sqlite.CleanupSQLite()
|
||||
lualibs.CleanupEnv()
|
||||
return nil
|
||||
}
|
||||
|
||||
// NotifyFileChanged alerts the runner about file changes
|
||||
func (r *Runner) NotifyFileChanged(filePath string) bool {
|
||||
logger.Debugf("Runner notified of file change: %s", filePath)
|
||||
|
||||
module, isModule := r.moduleLoader.GetModuleByPath(filePath)
|
||||
if isModule {
|
||||
logger.Debugf("Refreshing module: %s", module)
|
||||
return r.RefreshModule(module)
|
||||
}
|
||||
|
||||
logger.Debugf("File change noted but no refresh needed: %s", filePath)
|
||||
return true
|
||||
}
|
||||
|
||||
// RefreshModule refreshes a specific module across all states
|
||||
func (r *Runner) RefreshModule(moduleName string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if !r.isRunning.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debugf("Refreshing module: %s", moduleName)
|
||||
|
||||
success := true
|
||||
for _, state := range r.states {
|
||||
if state == nil || state.inUse.Load() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.moduleLoader.RefreshModule(state.L, moduleName); err != nil {
|
||||
success = false
|
||||
logger.Debugf("Failed to refresh module %s in state %d: %v", moduleName, state.index, err)
|
||||
}
|
||||
}
|
||||
|
||||
if success {
|
||||
logger.Debugf("Successfully refreshed module: %s", moduleName)
|
||||
}
|
||||
|
||||
return success
|
||||
}
|
||||
|
||||
// RunScriptFile loads, compiles and executes a Lua script file
|
||||
func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
|
||||
if !r.isRunning.Load() {
|
||||
return nil, ErrRunnerClosed
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("script file not found: %s", filePath)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
scriptDir := filepath.Dir(absPath)
|
||||
|
||||
r.mu.Lock()
|
||||
prevScriptDir := r.scriptDir
|
||||
r.scriptDir = scriptDir
|
||||
r.moduleLoader.SetScriptDir(scriptDir)
|
||||
r.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
r.mu.Lock()
|
||||
r.scriptDir = prevScriptDir
|
||||
r.moduleLoader.SetScriptDir(prevScriptDir)
|
||||
r.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Get state from pool
|
||||
var stateIndex int
|
||||
select {
|
||||
case stateIndex = <-r.statePool:
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
state := r.states[stateIndex]
|
||||
if state == nil {
|
||||
r.statePool <- stateIndex
|
||||
return nil, ErrStateNotReady
|
||||
}
|
||||
|
||||
state.inUse.Store(true)
|
||||
|
||||
defer func() {
|
||||
state.inUse.Store(false)
|
||||
if r.isRunning.Load() {
|
||||
select {
|
||||
case r.statePool <- stateIndex:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Compile script
|
||||
bytecode, err := state.L.CompileBytecode(string(content), filepath.Base(absPath))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compilation error: %w", err)
|
||||
}
|
||||
|
||||
// Create simple context for script execution
|
||||
ctx := NewContext()
|
||||
defer ctx.Release()
|
||||
|
||||
ctx.Set("_script_path", absPath)
|
||||
ctx.Set("_script_dir", scriptDir)
|
||||
|
||||
// Execute script
|
||||
response, err := state.sandbox.Execute(state.L, bytecode, ctx, state.index)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execution error: %w", err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
@ -1,220 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"Moonshark/runner/lualibs"
|
||||
"Moonshark/runner/sqlite"
|
||||
"fmt"
|
||||
|
||||
"maps"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Sandbox provides a secure execution environment for Lua scripts
|
||||
type Sandbox struct {
|
||||
executorBytecode []byte
|
||||
}
|
||||
|
||||
// NewSandbox creates a new sandbox environment
|
||||
func NewSandbox() *Sandbox {
|
||||
return &Sandbox{}
|
||||
}
|
||||
|
||||
// Setup initializes the sandbox in a Lua state
|
||||
func (s *Sandbox) Setup(state *luajit.State, stateIndex int, verbose bool) error {
|
||||
// Load all embedded modules and sandbox
|
||||
if err := loadSandboxIntoState(state, verbose); err != nil {
|
||||
return fmt.Errorf("failed to load sandbox: %w", err)
|
||||
}
|
||||
|
||||
// Set the state index as a global variable
|
||||
state.PushNumber(float64(stateIndex))
|
||||
state.SetGlobal("__STATE_INDEX")
|
||||
|
||||
// Pre-compile the executor function for reuse
|
||||
executorCode := `return __execute`
|
||||
bytecode, err := state.CompileBytecode(executorCode, "executor")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to compile executor: %w", err)
|
||||
}
|
||||
s.executorBytecode = bytecode
|
||||
|
||||
// Register native functions
|
||||
if err := s.registerCoreFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute runs a Lua script in the sandbox with the given context
|
||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx ExecutionContext, stateIndex int) (*Response, error) {
|
||||
// Load script and executor
|
||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||
return nil, fmt.Errorf("failed to load bytecode: %w", err)
|
||||
}
|
||||
|
||||
if err := state.LoadBytecode(s.executorBytecode, "executor"); err != nil {
|
||||
state.Pop(1)
|
||||
return nil, fmt.Errorf("failed to load executor: %w", err)
|
||||
}
|
||||
|
||||
// Get __execute function
|
||||
if err := state.Call(0, 1); err != nil {
|
||||
state.Pop(1)
|
||||
return nil, fmt.Errorf("failed to get executor: %w", err)
|
||||
}
|
||||
|
||||
// Prepare response object
|
||||
response := map[string]any{
|
||||
"status": 200,
|
||||
"headers": make(map[string]string),
|
||||
"cookies": []any{},
|
||||
"metadata": make(map[string]any),
|
||||
"session": make(map[string]any),
|
||||
"flash": make(map[string]any),
|
||||
}
|
||||
|
||||
// Call __execute(script_func, ctx, response)
|
||||
state.PushCopy(-2) // script function
|
||||
state.PushValue(ctx.ToMap())
|
||||
state.PushValue(response)
|
||||
|
||||
if err := state.Call(3, 1); err != nil {
|
||||
state.Pop(1)
|
||||
return nil, fmt.Errorf("script execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Extract result
|
||||
result, _ := state.ToValue(-1)
|
||||
state.Pop(2) // Clean up
|
||||
|
||||
sqlite.CleanupStateConnection(stateIndex)
|
||||
|
||||
var modifiedResponse map[string]any
|
||||
var scriptResult any
|
||||
|
||||
if arr, ok := result.([]any); ok && len(arr) >= 2 {
|
||||
scriptResult = arr[0]
|
||||
if resp, ok := arr[1].(map[string]any); ok {
|
||||
modifiedResponse = resp
|
||||
}
|
||||
}
|
||||
|
||||
if modifiedResponse == nil {
|
||||
scriptResult = result
|
||||
modifiedResponse = response
|
||||
}
|
||||
|
||||
return s.buildResponse(modifiedResponse, scriptResult), nil
|
||||
}
|
||||
|
||||
// buildResponse converts the Lua response object to a Go Response
|
||||
func (s *Sandbox) buildResponse(luaResp map[string]any, body any) *Response {
|
||||
resp := NewResponse()
|
||||
resp.Body = body
|
||||
|
||||
// Extract status
|
||||
if status, ok := luaResp["status"].(float64); ok {
|
||||
resp.Status = int(status)
|
||||
} else if status, ok := luaResp["status"].(int); ok {
|
||||
resp.Status = status
|
||||
}
|
||||
|
||||
// Extract headers
|
||||
if headers, ok := luaResp["headers"].(map[string]any); ok {
|
||||
for k, v := range headers {
|
||||
if str, ok := v.(string); ok {
|
||||
resp.Headers[k] = str
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract cookies
|
||||
if cookies, ok := luaResp["cookies"].([]any); ok {
|
||||
for _, cookieData := range cookies {
|
||||
if cookieMap, ok := cookieData.(map[string]any); ok {
|
||||
cookie := fasthttp.AcquireCookie()
|
||||
|
||||
if name, ok := cookieMap["name"].(string); ok && name != "" {
|
||||
cookie.SetKey(name)
|
||||
if value, ok := cookieMap["value"].(string); ok {
|
||||
cookie.SetValue(value)
|
||||
}
|
||||
if path, ok := cookieMap["path"].(string); ok {
|
||||
cookie.SetPath(path)
|
||||
}
|
||||
if domain, ok := cookieMap["domain"].(string); ok {
|
||||
cookie.SetDomain(domain)
|
||||
}
|
||||
if httpOnly, ok := cookieMap["http_only"].(bool); ok {
|
||||
cookie.SetHTTPOnly(httpOnly)
|
||||
}
|
||||
if secure, ok := cookieMap["secure"].(bool); ok {
|
||||
cookie.SetSecure(secure)
|
||||
}
|
||||
if maxAge, ok := cookieMap["max_age"].(float64); ok {
|
||||
cookie.SetMaxAge(int(maxAge))
|
||||
} else if maxAge, ok := cookieMap["max_age"].(int); ok {
|
||||
cookie.SetMaxAge(maxAge)
|
||||
}
|
||||
|
||||
resp.Cookies = append(resp.Cookies, cookie)
|
||||
} else {
|
||||
fasthttp.ReleaseCookie(cookie)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract metadata - simplified
|
||||
if metadata, ok := luaResp["metadata"].(map[string]any); ok {
|
||||
maps.Copy(resp.Metadata, metadata)
|
||||
}
|
||||
|
||||
// Extract session data - simplified
|
||||
if session, ok := luaResp["session"].(map[string]any); ok {
|
||||
maps.Copy(resp.SessionData, session)
|
||||
}
|
||||
|
||||
// Extract flash data and add to metadata for processing by server
|
||||
if flash, ok := luaResp["flash"].(map[string]any); ok && len(flash) > 0 {
|
||||
resp.Metadata["flash"] = flash
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// registerCoreFunctions registers all built-in functions in the Lua state
|
||||
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
|
||||
if err := lualibs.RegisterCryptoFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := lualibs.RegisterEnvFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := lualibs.RegisterFSFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := lualibs.RegisterHttpFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := lualibs.RegisterPasswordFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := sqlite.RegisterSQLiteFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := lualibs.RegisterUtilFunctions(state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,466 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sqlite "zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
|
||||
"Moonshark/logger"
|
||||
|
||||
"git.sharkk.net/Go/Color"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
var (
|
||||
dbPools = make(map[string]*sqlitex.Pool)
|
||||
poolsMu sync.RWMutex
|
||||
dataDir string
|
||||
poolSize = 8
|
||||
connTimeout = 5 * time.Second
|
||||
|
||||
// Per-state connection cache
|
||||
stateConns = make(map[string]*stateConn)
|
||||
stateConnsMu sync.RWMutex
|
||||
)
|
||||
|
||||
// stateConn tracks a connection and its origin pool
|
||||
type stateConn struct {
|
||||
conn *sqlite.Conn
|
||||
pool *sqlitex.Pool
|
||||
}
|
||||
|
||||
func InitSQLite(dir string) {
|
||||
dataDir = dir
|
||||
logger.Infof("SQLite is g2g! %s", color.Yellow(dir))
|
||||
}
|
||||
|
||||
func SetSQLitePoolSize(size int) {
|
||||
if size > 0 {
|
||||
poolSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func CleanupSQLite() {
|
||||
poolsMu.Lock()
|
||||
defer poolsMu.Unlock()
|
||||
|
||||
// Return all cached connections to their pools
|
||||
stateConnsMu.Lock()
|
||||
for _, sc := range stateConns {
|
||||
if sc.pool != nil && sc.conn != nil {
|
||||
sc.pool.Put(sc.conn)
|
||||
}
|
||||
}
|
||||
stateConns = make(map[string]*stateConn)
|
||||
stateConnsMu.Unlock()
|
||||
|
||||
for name, pool := range dbPools {
|
||||
if err := pool.Close(); err != nil {
|
||||
logger.Errorf("Failed to close database %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
dbPools = make(map[string]*sqlitex.Pool)
|
||||
logger.Debugf("SQLite connections closed")
|
||||
}
|
||||
|
||||
func getPool(dbName string) (*sqlitex.Pool, error) {
|
||||
dbName = filepath.Base(dbName)
|
||||
if dbName == "" || dbName[0] == '.' {
|
||||
return nil, fmt.Errorf("invalid database name")
|
||||
}
|
||||
|
||||
poolsMu.RLock()
|
||||
pool, exists := dbPools[dbName]
|
||||
if exists {
|
||||
poolsMu.RUnlock()
|
||||
return pool, nil
|
||||
}
|
||||
poolsMu.RUnlock()
|
||||
|
||||
poolsMu.Lock()
|
||||
defer poolsMu.Unlock()
|
||||
|
||||
if pool, exists = dbPools[dbName]; exists {
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
dbPath := filepath.Join(dataDir, dbName+".db")
|
||||
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
||||
PoolSize: poolSize,
|
||||
PrepareConn: func(conn *sqlite.Conn) error {
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode = WAL",
|
||||
"PRAGMA synchronous = NORMAL",
|
||||
"PRAGMA cache_size = 1000",
|
||||
"PRAGMA foreign_keys = ON",
|
||||
"PRAGMA temp_store = MEMORY",
|
||||
}
|
||||
for _, pragma := range pragmas {
|
||||
if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
dbPools[dbName] = pool
|
||||
logger.Debugf("Created SQLite pool for %s (size: %d)", dbName, poolSize)
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// getStateConnection gets or creates a reusable connection for the state+db
|
||||
func getStateConnection(stateIndex int, dbName string) (*sqlite.Conn, error) {
|
||||
connKey := fmt.Sprintf("%d-%s", stateIndex, dbName)
|
||||
|
||||
stateConnsMu.RLock()
|
||||
sc, exists := stateConns[connKey]
|
||||
stateConnsMu.RUnlock()
|
||||
|
||||
if exists && sc.conn != nil {
|
||||
return sc.conn, nil
|
||||
}
|
||||
|
||||
// Get new connection from pool
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connection timeout: %w", err)
|
||||
}
|
||||
|
||||
// Cache it with pool reference
|
||||
stateConnsMu.Lock()
|
||||
stateConns[connKey] = &stateConn{
|
||||
conn: conn,
|
||||
pool: pool,
|
||||
}
|
||||
stateConnsMu.Unlock()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func sqlQuery(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(3); err != nil {
|
||||
return state.PushError("sqlite.query: %v", err)
|
||||
}
|
||||
|
||||
dbName, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.query: database name must be string")
|
||||
}
|
||||
|
||||
query, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.query: query must be string")
|
||||
}
|
||||
|
||||
stateIndex := int(state.ToNumber(-1))
|
||||
|
||||
conn, err := getStateConnection(stateIndex, dbName)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.query: %v", err)
|
||||
}
|
||||
|
||||
var execOpts sqlitex.ExecOptions
|
||||
rows := make([]any, 0, 16)
|
||||
|
||||
if state.GetTop() >= 4 && !state.IsNil(3) {
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
return state.PushError("sqlite.query: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
row := make(map[string]any)
|
||||
colCount := stmt.ColumnCount()
|
||||
|
||||
for i := range colCount {
|
||||
colName := stmt.ColumnName(i)
|
||||
switch stmt.ColumnType(i) {
|
||||
case sqlite.TypeInteger:
|
||||
row[colName] = stmt.ColumnInt64(i)
|
||||
case sqlite.TypeFloat:
|
||||
row[colName] = stmt.ColumnFloat(i)
|
||||
case sqlite.TypeText:
|
||||
row[colName] = stmt.ColumnText(i)
|
||||
case sqlite.TypeBlob:
|
||||
blobSize := stmt.ColumnLen(i)
|
||||
if blobSize > 0 {
|
||||
buf := make([]byte, blobSize)
|
||||
row[colName] = stmt.ColumnBytes(i, buf)
|
||||
} else {
|
||||
row[colName] = []byte{}
|
||||
}
|
||||
case sqlite.TypeNull:
|
||||
row[colName] = nil
|
||||
}
|
||||
}
|
||||
rows = append(rows, row)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
return state.PushError("sqlite.query: %v", err)
|
||||
}
|
||||
|
||||
if err := state.PushValue(rows); err != nil {
|
||||
return state.PushError("sqlite.query: %v", err)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func sqlExec(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(3); err != nil {
|
||||
return state.PushError("sqlite.exec: %v", err)
|
||||
}
|
||||
|
||||
dbName, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.exec: database name must be string")
|
||||
}
|
||||
|
||||
query, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.exec: query must be string")
|
||||
}
|
||||
|
||||
stateIndex := int(state.ToNumber(-1))
|
||||
|
||||
conn, err := getStateConnection(stateIndex, dbName)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.exec: %v", err)
|
||||
}
|
||||
|
||||
hasParams := state.GetTop() >= 4 && !state.IsNil(3)
|
||||
|
||||
if strings.Contains(query, ";") && !hasParams {
|
||||
if err := sqlitex.ExecScript(conn, query); err != nil {
|
||||
return state.PushError("sqlite.exec: %v", err)
|
||||
}
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
}
|
||||
|
||||
if !hasParams {
|
||||
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
||||
return state.PushError("sqlite.exec: %v", err)
|
||||
}
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
}
|
||||
|
||||
var execOpts sqlitex.ExecOptions
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
return state.PushError("sqlite.exec: %v", err)
|
||||
}
|
||||
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
return state.PushError("sqlite.exec: %v", err)
|
||||
}
|
||||
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
}
|
||||
|
||||
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
|
||||
if state.IsTable(paramIndex) {
|
||||
paramsAny, err := state.ToTable(paramIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid parameters: %w", err)
|
||||
}
|
||||
|
||||
// Handle direct array types
|
||||
if arrParams, ok := paramsAny.([]any); ok {
|
||||
execOpts.Args = arrParams
|
||||
return nil
|
||||
}
|
||||
if strArr, ok := paramsAny.([]string); ok {
|
||||
args := make([]any, len(strArr))
|
||||
for i, v := range strArr {
|
||||
args[i] = v
|
||||
}
|
||||
execOpts.Args = args
|
||||
return nil
|
||||
}
|
||||
if floatArr, ok := paramsAny.([]float64); ok {
|
||||
args := make([]any, len(floatArr))
|
||||
for i, v := range floatArr {
|
||||
args[i] = v
|
||||
}
|
||||
execOpts.Args = args
|
||||
return nil
|
||||
}
|
||||
|
||||
params, ok := paramsAny.(map[string]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported parameter type: %T", paramsAny)
|
||||
}
|
||||
|
||||
// Check for array-style parameters (empty string key indicates array)
|
||||
if arr, ok := params[""]; ok {
|
||||
if arrParams, ok := arr.([]any); ok {
|
||||
execOpts.Args = arrParams
|
||||
} else if floatArr, ok := arr.([]float64); ok {
|
||||
args := make([]any, len(floatArr))
|
||||
for i, v := range floatArr {
|
||||
args[i] = v
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
} else {
|
||||
// Named parameters
|
||||
named := make(map[string]any, len(params))
|
||||
for k, v := range params {
|
||||
if len(k) > 0 && k[0] != ':' {
|
||||
named[":"+k] = v
|
||||
} else {
|
||||
named[k] = v
|
||||
}
|
||||
}
|
||||
execOpts.Named = named
|
||||
}
|
||||
} else {
|
||||
// Multiple individual parameters
|
||||
count := state.GetTop() - 2
|
||||
args := make([]any, count)
|
||||
for i := range count {
|
||||
idx := i + 3
|
||||
val, err := state.ToValue(idx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid parameter %d: %w", i+1, err)
|
||||
}
|
||||
args[i] = val
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqlGetOne(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(3); err != nil {
|
||||
return state.PushError("sqlite.get_one: %v", err)
|
||||
}
|
||||
|
||||
dbName, err := state.SafeToString(1)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.get_one: database name must be string")
|
||||
}
|
||||
|
||||
query, err := state.SafeToString(2)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.get_one: query must be string")
|
||||
}
|
||||
|
||||
stateIndex := int(state.ToNumber(-1))
|
||||
|
||||
conn, err := getStateConnection(stateIndex, dbName)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.get_one: %v", err)
|
||||
}
|
||||
|
||||
var execOpts sqlitex.ExecOptions
|
||||
var result map[string]any
|
||||
|
||||
// Check if params provided (before state index)
|
||||
if state.GetTop() >= 4 && !state.IsNil(3) {
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
return state.PushError("sqlite.get_one: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
if result != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result = make(map[string]any)
|
||||
colCount := stmt.ColumnCount()
|
||||
|
||||
for i := range colCount {
|
||||
colName := stmt.ColumnName(i)
|
||||
switch stmt.ColumnType(i) {
|
||||
case sqlite.TypeInteger:
|
||||
result[colName] = stmt.ColumnInt64(i)
|
||||
case sqlite.TypeFloat:
|
||||
result[colName] = stmt.ColumnFloat(i)
|
||||
case sqlite.TypeText:
|
||||
result[colName] = stmt.ColumnText(i)
|
||||
case sqlite.TypeBlob:
|
||||
blobSize := stmt.ColumnLen(i)
|
||||
if blobSize > 0 {
|
||||
buf := make([]byte, blobSize)
|
||||
result[colName] = stmt.ColumnBytes(i, buf)
|
||||
} else {
|
||||
result[colName] = []byte{}
|
||||
}
|
||||
case sqlite.TypeNull:
|
||||
result[colName] = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
return state.PushError("sqlite.get_one: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
state.PushNil()
|
||||
} else {
|
||||
if err := state.PushValue(result); err != nil {
|
||||
return state.PushError("sqlite.get_one: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// CleanupStateConnection releases all connections for a specific state
|
||||
func CleanupStateConnection(stateIndex int) {
|
||||
stateConnsMu.Lock()
|
||||
defer stateConnsMu.Unlock()
|
||||
|
||||
statePrefix := fmt.Sprintf("%d-", stateIndex)
|
||||
|
||||
for key, sc := range stateConns {
|
||||
if strings.HasPrefix(key, statePrefix) {
|
||||
if sc.pool != nil && sc.conn != nil {
|
||||
sc.pool.Put(sc.conn)
|
||||
}
|
||||
delete(stateConns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterSQLiteFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
189
test.lua
Normal file
189
test.lua
Normal file
@ -0,0 +1,189 @@
|
||||
-- Example HTTP server with string-based routing, parameters, and wildcards
|
||||
|
||||
print("Starting Moonshark HTTP server with string routing...")
|
||||
|
||||
-- Start HTTP server
|
||||
http.listen(3000)
|
||||
|
||||
-- Home page
|
||||
http.route("GET", "/", function(req)
|
||||
local visits = session.get("visits") or 0
|
||||
visits = visits + 1
|
||||
session.set("visits", visits)
|
||||
|
||||
return http.html([[
|
||||
<h1>Welcome to Moonshark!</h1>
|
||||
<p>You've visited this page ]] .. visits .. [[ times.</p>
|
||||
<p><a href="/login">Login</a> | <a href="/users/123">User Profile</a></p>
|
||||
<p><a href="/api/test">API Test</a> | <a href="/files/docs/readme.txt">File Access</a></p>
|
||||
]])
|
||||
end)
|
||||
|
||||
-- User profile with dynamic parameter
|
||||
http.route("GET", "/users/:id", function(req)
|
||||
local userId = req.params.id
|
||||
return http.html([[
|
||||
<h2>User Profile</h2>
|
||||
<p>User ID: ]] .. userId .. [[</p>
|
||||
<p><a href="/users/]] .. userId .. [[/posts">View Posts</a></p>
|
||||
<p><a href="/">Home</a></p>
|
||||
]])
|
||||
end)
|
||||
|
||||
-- User posts with multiple parameters
|
||||
http.route("GET", "/users/:id/posts", function(req)
|
||||
local userId = req.params.id
|
||||
return http.json({
|
||||
user_id = userId,
|
||||
posts = {
|
||||
{id = 1, title = "First Post", content = "Hello world!"},
|
||||
{id = 2, title = "Second Post", content = "Learning Lua routing!"}
|
||||
}
|
||||
})
|
||||
end)
|
||||
|
||||
-- Blog post with slug parameter
|
||||
http.route("GET", "/blog/:slug", function(req)
|
||||
local slug = req.params.slug
|
||||
return http.html([[
|
||||
<h1>Blog Post: ]] .. slug .. [[</h1>
|
||||
<p>This is the content for blog post "]] .. slug .. [["</p>
|
||||
<p><a href="/blog/]] .. slug .. [[/comments">View Comments</a></p>
|
||||
<p><a href="/">Home</a></p>
|
||||
]])
|
||||
end)
|
||||
|
||||
-- Blog comments
|
||||
http.route("GET", "/blog/:slug/comments", function(req)
|
||||
local slug = req.params.slug
|
||||
return http.json({
|
||||
blog_slug = slug,
|
||||
comments = {
|
||||
{author = "Alice", comment = "Great post!"},
|
||||
{author = "Bob", comment = "Very informative."}
|
||||
}
|
||||
})
|
||||
end)
|
||||
|
||||
-- Wildcard route for file serving
|
||||
http.route("GET", "/files/*path", function(req)
|
||||
local filePath = req.params.path
|
||||
return http.html([[
|
||||
<h2>File Access</h2>
|
||||
<p>Requested file: ]] .. filePath .. [[</p>
|
||||
<p>In a real application, this would serve the file content.</p>
|
||||
<p><a href="/">Home</a></p>
|
||||
]])
|
||||
end)
|
||||
|
||||
-- API endpoints with parameters
|
||||
http.route("GET", "/api/users/:id", function(req)
|
||||
local userId = req.params.id
|
||||
return http.json({
|
||||
id = tonumber(userId),
|
||||
name = "User " .. userId,
|
||||
email = "user" .. userId .. "@example.com",
|
||||
active = true
|
||||
})
|
||||
end)
|
||||
|
||||
http.route("PUT", "/api/users/:id", function(req)
|
||||
local userId = req.params.id
|
||||
local userData = req.form
|
||||
|
||||
return http.json({
|
||||
success = true,
|
||||
message = "User " .. userId .. " updated",
|
||||
data = userData
|
||||
})
|
||||
end)
|
||||
|
||||
http.route("DELETE", "/api/users/:id", function(req)
|
||||
local userId = req.params.id
|
||||
return http.json({
|
||||
success = true,
|
||||
message = "User " .. userId .. " deleted"
|
||||
})
|
||||
end)
|
||||
|
||||
-- Login form with CSRF protection
|
||||
http.route("GET", "/login", function(req)
|
||||
return http.html([[
|
||||
<h2>Login</h2>
|
||||
<form method="POST" action="/login">
|
||||
]] .. csrf.field() .. [[
|
||||
<input type="text" name="username" placeholder="Username" required><br>
|
||||
<input type="password" name="password" placeholder="Password" required><br>
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
<p><a href="/">Home</a></p>
|
||||
]])
|
||||
end)
|
||||
|
||||
-- Handle login POST
|
||||
http.route("POST", "/login", function(req)
|
||||
if not csrf.validate() then
|
||||
http.status(403)
|
||||
return "CSRF token invalid"
|
||||
end
|
||||
|
||||
local username = req.form.username
|
||||
local password = req.form.password
|
||||
|
||||
if username == "admin" and password == "secret" then
|
||||
session.set("user", username)
|
||||
session.flash("success", "Login successful!")
|
||||
return http.redirect("/dashboard")
|
||||
else
|
||||
session.flash("error", "Invalid credentials")
|
||||
return http.redirect("/login")
|
||||
end
|
||||
end)
|
||||
|
||||
-- Dashboard (requires login)
|
||||
http.route("GET", "/dashboard", function(req)
|
||||
local user = session.get("user")
|
||||
if not user then
|
||||
return http.redirect("/login")
|
||||
end
|
||||
|
||||
local success = session.get_flash("success")
|
||||
local error = session.get_flash("error")
|
||||
|
||||
return http.html([[
|
||||
<h2>Dashboard</h2>
|
||||
]] .. (success and ("<p style='color:green'>" .. success .. "</p>") or "") .. [[
|
||||
]] .. (error and ("<p style='color:red'>" .. error .. "</p>") or "") .. [[
|
||||
<p>Welcome, ]] .. user .. [[!</p>
|
||||
<p><a href="/logout">Logout</a></p>
|
||||
<p><a href="/">Home</a></p>
|
||||
]])
|
||||
end)
|
||||
|
||||
-- Logout
|
||||
http.route("GET", "/logout", function(req)
|
||||
session.set("user", nil)
|
||||
session.flash("info", "You have been logged out")
|
||||
return http.redirect("/")
|
||||
end)
|
||||
|
||||
-- Catch-all route for 404s (must be last)
|
||||
http.route("GET", "*path", function(req)
|
||||
http.status(404)
|
||||
return http.html([[
|
||||
<h1>404 - Page Not Found</h1>
|
||||
<p>The requested path "]] .. req.params.path .. [[" was not found.</p>
|
||||
<p><a href="/">Go Home</a></p>
|
||||
]])
|
||||
end)
|
||||
|
||||
print("Server configured with string routing. Listening on http://localhost:3000")
|
||||
print("Try these routes:")
|
||||
print(" GET /")
|
||||
print(" GET /users/123")
|
||||
print(" GET /users/456/posts")
|
||||
print(" GET /blog/my-first-post")
|
||||
print(" GET /blog/lua-tutorial/comments")
|
||||
print(" GET /files/docs/readme.txt")
|
||||
print(" GET /api/users/789")
|
||||
print(" GET /nonexistent (404 handler)")
|
217
utils/debug.go
217
utils/debug.go
@ -1,217 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"Moonshark/config"
|
||||
"Moonshark/metadata"
|
||||
)
|
||||
|
||||
// ComponentStats holds stats from various system components
|
||||
type ComponentStats struct {
|
||||
RouteCount int // Number of routes
|
||||
BytecodeBytes int64 // Total size of bytecode in bytes
|
||||
ModuleCount int // Number of loaded modules
|
||||
SessionStats map[string]uint64 // Session cache statistics
|
||||
}
|
||||
|
||||
// SystemStats represents system statistics for debugging
|
||||
type SystemStats struct {
|
||||
Timestamp time.Time
|
||||
GoVersion string
|
||||
GoRoutines int
|
||||
Memory runtime.MemStats
|
||||
Components ComponentStats
|
||||
Version string
|
||||
Config *config.Config
|
||||
}
|
||||
|
||||
// CollectSystemStats gathers basic system statistics
|
||||
func CollectSystemStats(cfg *config.Config) SystemStats {
|
||||
var stats SystemStats
|
||||
var mem runtime.MemStats
|
||||
|
||||
stats.Timestamp = time.Now()
|
||||
stats.GoVersion = runtime.Version()
|
||||
stats.GoRoutines = runtime.NumGoroutine()
|
||||
stats.Version = metadata.Version
|
||||
stats.Config = cfg
|
||||
|
||||
runtime.ReadMemStats(&mem)
|
||||
stats.Memory = mem
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// DebugStatsPage generates an HTML debug stats page
|
||||
func DebugStatsPage(stats SystemStats) string {
|
||||
const debugTemplate = `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Moonshark</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
max-width: 900px;
|
||||
margin: 0 auto;
|
||||
background-color: #333;
|
||||
color: white;
|
||||
}
|
||||
h1 {
|
||||
padding: 1rem;
|
||||
background-color: #4F5B93;
|
||||
box-shadow: 0 2px 4px 0px rgba(0, 0, 0, 0.2);
|
||||
margin-top: 0;
|
||||
}
|
||||
h2 { margin-top: 0; margin-bottom: 0.5rem; }
|
||||
table { width: 100%; border-collapse: collapse; }
|
||||
th { width: 1%; white-space: nowrap; border-right: 1px solid rgba(0, 0, 0, 0.1); }
|
||||
th, td { text-align: left; padding: 0.75rem 0.5rem; border-bottom: 1px solid #ddd; }
|
||||
tr:last-child th, tr:last-child td { border-bottom: none; }
|
||||
table tr:nth-child(even), tbody tr:nth-child(even) { background-color: rgba(0, 0, 0, 0.1); }
|
||||
.card {
|
||||
background: #F2F2F2;
|
||||
color: #333;
|
||||
border-radius: 4px;
|
||||
margin-bottom: 1rem;
|
||||
box-shadow: 0 2px 4px 0px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
.timestamp { color: #999; font-size: 0.9em; margin-bottom: 1rem; }
|
||||
.section { margin-bottom: 30px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Moonshark</h1>
|
||||
<div class="timestamp">Generated at: {{.Timestamp.Format "2006-01-02 15:04:05"}}</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Server</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Version</th><td>{{.Version}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>System</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Go Version</th><td>{{.GoVersion}}</td></tr>
|
||||
<tr><th>Goroutines</th><td>{{.GoRoutines}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Memory</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Allocated</th><td>{{ByteCount .Memory.Alloc}}</td></tr>
|
||||
<tr><th>Total Allocated</th><td>{{ByteCount .Memory.TotalAlloc}}</td></tr>
|
||||
<tr><th>System Memory</th><td>{{ByteCount .Memory.Sys}}</td></tr>
|
||||
<tr><th>GC Cycles</th><td>{{.Memory.NumGC}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Sessions</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Active Sessions</th><td>{{index .Components.SessionStats "entries"}}</td></tr>
|
||||
<tr><th>Cache Size</th><td>{{ByteCount (index .Components.SessionStats "bytes")}}</td></tr>
|
||||
<tr><th>Max Cache Size</th><td>{{ByteCount (index .Components.SessionStats "max_bytes")}}</td></tr>
|
||||
<tr><th>Cache Gets</th><td>{{index .Components.SessionStats "gets"}}</td></tr>
|
||||
<tr><th>Cache Sets</th><td>{{index .Components.SessionStats "sets"}}</td></tr>
|
||||
<tr><th>Cache Misses</th><td>{{index .Components.SessionStats "misses"}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>LuaRunner</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Interpreter</th><td>LuaJIT 2.1 (Lua 5.1)</td></tr>
|
||||
<tr><th>Active Routes</th><td>{{.Components.RouteCount}}</td></tr>
|
||||
<tr><th>Bytecode Size</th><td>{{ByteCount .Components.BytecodeBytes}}</td></tr>
|
||||
<tr><th>Loaded Modules</th><td>{{.Components.ModuleCount}}</td></tr>
|
||||
<tr><th>State Pool Size</th><td>{{.Config.Runner.PoolSize}}</td></tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Config</h2>
|
||||
<div class="card">
|
||||
<table>
|
||||
<tr><th>Port</th><td>{{.Config.Server.Port}}</td></tr>
|
||||
<tr><th>Pool Size</th><td>{{.Config.Runner.PoolSize}}</td></tr>
|
||||
<tr><th>Debug Mode</th><td>{{.Config.Server.Debug}}</td></tr>
|
||||
<tr><th>Log Level</th><td>{{.Config.Server.LogLevel}}</td></tr>
|
||||
<tr><th>HTTP Logging</th><td>{{.Config.Server.HTTPLogging}}</td></tr>
|
||||
<tr>
|
||||
<th>Directories</th>
|
||||
<td>
|
||||
<div>Routes: {{.Config.Dirs.Routes}}</div>
|
||||
<div>Static: {{.Config.Dirs.Static}}</div>
|
||||
<div>Override: {{.Config.Dirs.Override}}</div>
|
||||
<div>Libs: {{range .Config.Dirs.Libs}}{{.}}, {{end}}</div>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
// Create a template function map
|
||||
funcMap := template.FuncMap{
|
||||
"ByteCount": func(b any) string {
|
||||
var bytes uint64
|
||||
|
||||
switch v := b.(type) {
|
||||
case uint64:
|
||||
bytes = v
|
||||
case int64:
|
||||
bytes = uint64(v)
|
||||
case int:
|
||||
bytes = uint64(v)
|
||||
default:
|
||||
return fmt.Sprintf("%T: %v", b, b)
|
||||
}
|
||||
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := uint64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
},
|
||||
}
|
||||
|
||||
// Parse the template
|
||||
tmpl, err := template.New("debug").Funcs(funcMap).Parse(debugTemplate)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error parsing template: %v", err)
|
||||
}
|
||||
|
||||
// Execute the template
|
||||
var output strings.Builder
|
||||
if err := tmpl.Execute(&output, stats); err != nil {
|
||||
return fmt.Sprintf("Error executing template: %v", err)
|
||||
}
|
||||
|
||||
return output.String()
|
||||
}
|
@ -1,231 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
var dbg bool
|
||||
|
||||
// ErrorType represents HTTP error types
|
||||
type ErrorType int
|
||||
|
||||
const (
|
||||
ErrorTypeNotFound ErrorType = 404
|
||||
ErrorTypeMethodNotAllowed ErrorType = 405
|
||||
ErrorTypeInternalError ErrorType = 500
|
||||
ErrorTypeForbidden ErrorType = 403 // Added CSRF/Forbidden error type
|
||||
)
|
||||
|
||||
func Debug(enabled bool) { dbg = enabled }
|
||||
|
||||
// ErrorPage generates an HTML error page based on the error type
|
||||
// It first checks for an override file, and if not found, generates a default page
|
||||
func ErrorPage(errorType ErrorType, url string, errMsg string) string {
|
||||
// No override found, generate default page
|
||||
switch errorType {
|
||||
case ErrorTypeNotFound:
|
||||
return generateNotFoundHTML(url)
|
||||
case ErrorTypeMethodNotAllowed:
|
||||
return generateMethodNotAllowedHTML(url)
|
||||
case ErrorTypeInternalError:
|
||||
return generateInternalErrorHTML(dbg, url, errMsg)
|
||||
case ErrorTypeForbidden:
|
||||
return generateForbiddenHTML(dbg, url, errMsg)
|
||||
default:
|
||||
// Fallback to internal error
|
||||
return generateInternalErrorHTML(dbg, url, errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// NotFoundPage generates a 404 Not Found error page
|
||||
func NotFoundPage(url string) string {
|
||||
return ErrorPage(ErrorTypeNotFound, url, "")
|
||||
}
|
||||
|
||||
// MethodNotAllowedPage generates a 405 Method Not Allowed error page
|
||||
func MethodNotAllowedPage(url string) string {
|
||||
return ErrorPage(ErrorTypeMethodNotAllowed, url, "")
|
||||
}
|
||||
|
||||
// InternalErrorPage generates a 500 Internal Server Error page
|
||||
func InternalErrorPage(url string, errMsg string) string {
|
||||
return ErrorPage(ErrorTypeInternalError, url, errMsg)
|
||||
}
|
||||
|
||||
// ForbiddenPage generates a 403 Forbidden error page
|
||||
func ForbiddenPage(url string, errMsg string) string {
|
||||
return ErrorPage(ErrorTypeForbidden, url, errMsg)
|
||||
}
|
||||
|
||||
// generateInternalErrorHTML creates a 500 Internal Server Error page
|
||||
func generateInternalErrorHTML(debugMode bool, url string, errMsg string) string {
|
||||
errorMessages := []string{
|
||||
"Oops! Something went wrong",
|
||||
"Oh no! The server choked",
|
||||
"Well, this is embarrassing...",
|
||||
"Houston, we have a problem",
|
||||
"Gremlins in the system",
|
||||
"The server is taking a coffee break",
|
||||
"Moonshark encountered a lunar eclipse",
|
||||
"Our code monkeys are working on it",
|
||||
"The server is feeling under the weather",
|
||||
"500 Brain Not Found",
|
||||
}
|
||||
|
||||
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
|
||||
return generateErrorHTML("500", randomMessage, "Internal Server Error", debugMode, errMsg)
|
||||
}
|
||||
|
||||
// generateForbiddenHTML creates a 403 Forbidden error page
|
||||
func generateForbiddenHTML(debugMode bool, url string, errMsg string) string {
|
||||
errorMessages := []string{
|
||||
"Access denied",
|
||||
"You shall not pass",
|
||||
"This area is off-limits",
|
||||
"Security check failed",
|
||||
"Invalid security token",
|
||||
"Request blocked for security reasons",
|
||||
"Permission denied",
|
||||
"Security violation detected",
|
||||
"This request was rejected",
|
||||
"Security first, access second",
|
||||
}
|
||||
|
||||
defaultMsg := "Invalid or missing CSRF token. This could be due to an expired form or a cross-site request forgery attempt."
|
||||
if errMsg == "" {
|
||||
errMsg = defaultMsg
|
||||
}
|
||||
|
||||
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
|
||||
return generateErrorHTML("403", randomMessage, "Forbidden", debugMode, errMsg)
|
||||
}
|
||||
|
||||
// generateNotFoundHTML creates a 404 Not Found error page
|
||||
func generateNotFoundHTML(url string) string {
|
||||
errorMessages := []string{
|
||||
"Nothing to see here",
|
||||
"This page is on vacation",
|
||||
"The page is missing in action",
|
||||
"This page has left the building",
|
||||
"This page is in another castle",
|
||||
"Sorry, we can't find that",
|
||||
"The page you're looking for doesn't exist",
|
||||
"Lost in space",
|
||||
"That's a 404",
|
||||
"Page not found",
|
||||
}
|
||||
|
||||
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
|
||||
return generateErrorHTML("404", randomMessage, "Page Not Found", false, url)
|
||||
}
|
||||
|
||||
// generateMethodNotAllowedHTML creates a 405 Method Not Allowed error page
|
||||
func generateMethodNotAllowedHTML(url string) string {
|
||||
errorMessages := []string{
|
||||
"That's not how this works",
|
||||
"Method not allowed",
|
||||
"Wrong way!",
|
||||
"This method is not supported",
|
||||
"You can't do that here",
|
||||
"Sorry, wrong door",
|
||||
"That method won't work here",
|
||||
"Try a different approach",
|
||||
"Access denied for this method",
|
||||
"Method mismatch",
|
||||
}
|
||||
|
||||
randomMessage := errorMessages[rand.Intn(len(errorMessages))]
|
||||
return generateErrorHTML("405", randomMessage, "Method Not Allowed", false, url)
|
||||
}
|
||||
|
||||
// generateErrorHTML creates the common HTML structure for error pages
|
||||
func generateErrorHTML(errorCode, mainMessage, subMessage string, showDebugInfo bool, codeContent string) string {
|
||||
errorHTML := `<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>` + errorCode + `</title>
|
||||
<style>
|
||||
:root {
|
||||
--bg-color: #2d2e2d;
|
||||
--bg-gradient: linear-gradient(to bottom, #2d2e2d 0%, #000 100%);
|
||||
--text-color: white;
|
||||
--code-bg: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: light) {
|
||||
:root {
|
||||
--bg-color: #f5f5f5;
|
||||
--bg-gradient: linear-gradient(to bottom, #f5f5f5 0%, #ddd 100%);
|
||||
--text-color: #333;
|
||||
--code-bg: rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
background-color: var(--bg-color);
|
||||
color: var(--text-color);
|
||||
background: var(--bg-gradient);
|
||||
}
|
||||
h1 {
|
||||
font-size: 4rem;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
p {
|
||||
font-size: 1.5rem;
|
||||
margin: 0.5rem 0;
|
||||
padding: 0;
|
||||
}
|
||||
.sub-message {
|
||||
font-size: 1.2rem;
|
||||
margin-bottom: 1rem;
|
||||
opacity: 0.8;
|
||||
}
|
||||
code {
|
||||
display: inline-block;
|
||||
font-size: 1rem;
|
||||
font-family: monospace;
|
||||
background-color: var(--code-bg);
|
||||
padding: 0.25em 0.5em;
|
||||
border-radius: 0.25em;
|
||||
margin-top: 1rem;
|
||||
max-width: 90vw;
|
||||
overflow-wrap: break-word;
|
||||
word-break: break-all;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div>
|
||||
<h1>` + errorCode + `</h1>
|
||||
<p>` + mainMessage + `</p>
|
||||
<div class="sub-message">` + subMessage + `</div>`
|
||||
|
||||
if codeContent != "" {
|
||||
errorHTML += `
|
||||
<code>` + codeContent + `</code>`
|
||||
}
|
||||
|
||||
// Add a note for debug mode
|
||||
if showDebugInfo {
|
||||
errorHTML += `
|
||||
<p style="font-size: 0.9rem; margin-top: 1rem;">
|
||||
An error occurred while processing your request.<br>
|
||||
Please check the server logs for details.
|
||||
</p>`
|
||||
}
|
||||
|
||||
errorHTML += `
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
return errorHTML
|
||||
}
|
@ -1,115 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func ParseForm(ctx *fasthttp.RequestCtx) (map[string]any, error) {
|
||||
contentType := string(ctx.Request.Header.ContentType())
|
||||
formData := make(map[string]any)
|
||||
|
||||
switch {
|
||||
case strings.Contains(contentType, "multipart/form-data"):
|
||||
if err := parseMultipartInto(ctx, formData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case strings.Contains(contentType, "application/x-www-form-urlencoded"):
|
||||
args := ctx.PostArgs()
|
||||
args.VisitAll(func(key, value []byte) {
|
||||
appendValue(formData, string(key), string(value))
|
||||
})
|
||||
|
||||
case strings.Contains(contentType, "application/json"):
|
||||
if err := json.Unmarshal(ctx.PostBody(), &formData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
// Leave formData empty if content-type is unrecognized
|
||||
}
|
||||
|
||||
return formData, nil
|
||||
}
|
||||
|
||||
func parseMultipartInto(ctx *fasthttp.RequestCtx, formData map[string]any) error {
|
||||
form, err := ctx.MultipartForm()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for key, values := range form.Value {
|
||||
if len(values) == 1 {
|
||||
formData[key] = values[0]
|
||||
} else if len(values) > 1 {
|
||||
formData[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
if len(form.File) > 0 {
|
||||
files := make(map[string]any, len(form.File))
|
||||
for fieldName, fileHeaders := range form.File {
|
||||
if len(fileHeaders) == 1 {
|
||||
files[fieldName] = fileInfoToMap(fileHeaders[0])
|
||||
} else {
|
||||
fileInfos := make([]map[string]any, len(fileHeaders))
|
||||
for i, fh := range fileHeaders {
|
||||
fileInfos[i] = fileInfoToMap(fh)
|
||||
}
|
||||
files[fieldName] = fileInfos
|
||||
}
|
||||
}
|
||||
formData["_files"] = files
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendValue(formData map[string]any, key, value string) {
|
||||
if existing, exists := formData[key]; exists {
|
||||
switch v := existing.(type) {
|
||||
case string:
|
||||
formData[key] = []string{v, value}
|
||||
case []string:
|
||||
formData[key] = append(v, value)
|
||||
}
|
||||
} else {
|
||||
formData[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func fileInfoToMap(fh *multipart.FileHeader) map[string]any {
|
||||
ct := fh.Header.Get("Content-Type")
|
||||
if ct == "" {
|
||||
ct = getMimeType(fh.Filename)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"filename": fh.Filename,
|
||||
"size": fh.Size,
|
||||
"mimetype": ct,
|
||||
}
|
||||
}
|
||||
|
||||
func getMimeType(filename string) string {
|
||||
if i := strings.LastIndex(filename, "."); i >= 0 {
|
||||
switch filename[i:] {
|
||||
case ".pdf":
|
||||
return "application/pdf"
|
||||
case ".png":
|
||||
return "image/png"
|
||||
case ".jpg", ".jpeg":
|
||||
return "image/jpeg"
|
||||
case ".gif":
|
||||
return "image/gif"
|
||||
case ".svg":
|
||||
return "image/svg+xml"
|
||||
}
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
@ -1,92 +0,0 @@
|
||||
package watchers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"Moonshark/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDirectoryNotFound = errors.New("directory not found")
|
||||
ErrAlreadyWatching = errors.New("already watching directory")
|
||||
)
|
||||
|
||||
// WatcherManager now just manages watcher lifecycle - no polling logic
|
||||
type WatcherManager struct {
|
||||
watchers map[string]*Watcher
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewWatcherManager() *WatcherManager {
|
||||
return &WatcherManager{
|
||||
watchers: make(map[string]*Watcher),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *WatcherManager) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, watcher := range m.watchers {
|
||||
watcher.Close()
|
||||
}
|
||||
|
||||
m.watchers = make(map[string]*Watcher)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *WatcherManager) WatchDirectory(config WatcherConfig) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.watchers[config.Dir]; exists {
|
||||
return ErrAlreadyWatching
|
||||
}
|
||||
|
||||
watcher, err := NewWatcher(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.watchers[config.Dir] = watcher
|
||||
logger.Debugf("WatcherManager added watcher for %s", config.Dir)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *WatcherManager) UnwatchDirectory(dir string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
watcher, exists := m.watchers[dir]
|
||||
if !exists {
|
||||
return ErrDirectoryNotFound
|
||||
}
|
||||
|
||||
watcher.Close()
|
||||
delete(m.watchers, dir)
|
||||
logger.Debugf("WatcherManager removed watcher for %s", dir)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *WatcherManager) GetWatcher(dir string) (*Watcher, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
watcher, exists := m.watchers[dir]
|
||||
return watcher, exists
|
||||
}
|
||||
|
||||
func (m *WatcherManager) ListWatching() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
dirs := make([]string, 0, len(m.watchers))
|
||||
for dir := range m.watchers {
|
||||
dirs = append(dirs, dir)
|
||||
}
|
||||
|
||||
return dirs
|
||||
}
|
@ -1,236 +0,0 @@
|
||||
package watchers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"Moonshark/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDebounceTime = 300 * time.Millisecond
|
||||
defaultPollInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
type FileChange struct {
|
||||
Path string
|
||||
IsNew bool
|
||||
IsDeleted bool
|
||||
}
|
||||
|
||||
type FileInfo struct {
|
||||
ModTime time.Time
|
||||
}
|
||||
|
||||
// Watcher is now self-contained and manages its own polling
|
||||
type Watcher struct {
|
||||
dir string
|
||||
files map[string]FileInfo
|
||||
filesMu sync.RWMutex
|
||||
callback func([]FileChange) error
|
||||
debounceTime time.Duration
|
||||
pollInterval time.Duration
|
||||
recursive bool
|
||||
|
||||
// Self-management
|
||||
done chan struct{}
|
||||
debounceTimer *time.Timer
|
||||
debouncing bool
|
||||
debounceMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type WatcherConfig struct {
|
||||
Dir string
|
||||
Callback func([]FileChange) error
|
||||
DebounceTime time.Duration
|
||||
PollInterval time.Duration
|
||||
Recursive bool
|
||||
}
|
||||
|
||||
func NewWatcher(config WatcherConfig) (*Watcher, error) {
|
||||
if config.DebounceTime == 0 {
|
||||
config.DebounceTime = defaultDebounceTime
|
||||
}
|
||||
if config.PollInterval == 0 {
|
||||
config.PollInterval = defaultPollInterval
|
||||
}
|
||||
|
||||
w := &Watcher{
|
||||
dir: config.Dir,
|
||||
files: make(map[string]FileInfo),
|
||||
callback: config.Callback,
|
||||
debounceTime: config.DebounceTime,
|
||||
pollInterval: config.PollInterval,
|
||||
recursive: config.Recursive,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if err := w.scanDirectory(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w.wg.Add(1)
|
||||
go w.watchLoop()
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (w *Watcher) Close() {
|
||||
close(w.done)
|
||||
w.wg.Wait()
|
||||
|
||||
w.debounceMu.Lock()
|
||||
if w.debounceTimer != nil {
|
||||
w.debounceTimer.Stop()
|
||||
}
|
||||
w.debounceMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *Watcher) GetDir() string {
|
||||
return w.dir
|
||||
}
|
||||
|
||||
// watchLoop is the main polling loop for this watcher
|
||||
func (w *Watcher) watchLoop() {
|
||||
defer w.wg.Done()
|
||||
ticker := time.NewTicker(w.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
w.checkAndNotify()
|
||||
case <-w.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkAndNotify combines change detection and notification
|
||||
func (w *Watcher) checkAndNotify() {
|
||||
changed, changedFiles := w.detectChanges()
|
||||
if changed {
|
||||
w.notifyChange(changedFiles)
|
||||
}
|
||||
}
|
||||
|
||||
// detectChanges scans directory and returns changes
|
||||
func (w *Watcher) detectChanges() (bool, []FileChange) {
|
||||
newFiles := make(map[string]FileInfo)
|
||||
var changedFiles []FileChange
|
||||
changed := false
|
||||
|
||||
err := filepath.Walk(w.dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !w.recursive && info.IsDir() && path != w.dir {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
currentInfo := FileInfo{ModTime: info.ModTime()}
|
||||
newFiles[path] = currentInfo
|
||||
|
||||
w.filesMu.RLock()
|
||||
prevInfo, exists := w.files[path]
|
||||
w.filesMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
changed = true
|
||||
changedFiles = append(changedFiles, FileChange{Path: path, IsNew: true})
|
||||
w.logDebug("File added: %s", path)
|
||||
} else if currentInfo.ModTime != prevInfo.ModTime {
|
||||
changed = true
|
||||
changedFiles = append(changedFiles, FileChange{Path: path})
|
||||
w.logDebug("File changed: %s", path)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
w.logError("Error scanning directory: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check for deletions
|
||||
w.filesMu.RLock()
|
||||
for path := range w.files {
|
||||
if _, exists := newFiles[path]; !exists {
|
||||
changed = true
|
||||
changedFiles = append(changedFiles, FileChange{Path: path, IsDeleted: true})
|
||||
w.logDebug("File deleted: %s", path)
|
||||
}
|
||||
}
|
||||
w.filesMu.RUnlock()
|
||||
|
||||
if changed {
|
||||
w.filesMu.Lock()
|
||||
w.files = newFiles
|
||||
w.filesMu.Unlock()
|
||||
}
|
||||
|
||||
return changed, changedFiles
|
||||
}
|
||||
|
||||
func (w *Watcher) scanDirectory() error {
|
||||
w.filesMu.Lock()
|
||||
defer w.filesMu.Unlock()
|
||||
|
||||
w.files = make(map[string]FileInfo)
|
||||
|
||||
return filepath.Walk(w.dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !w.recursive && info.IsDir() && path != w.dir {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
w.files[path] = FileInfo{ModTime: info.ModTime()}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Watcher) notifyChange(changedFiles []FileChange) {
|
||||
w.debounceMu.Lock()
|
||||
defer w.debounceMu.Unlock()
|
||||
|
||||
if w.debouncing && w.debounceTimer != nil {
|
||||
w.debounceTimer.Stop()
|
||||
}
|
||||
w.debouncing = true
|
||||
|
||||
// Copy to avoid race conditions
|
||||
filesCopy := make([]FileChange, len(changedFiles))
|
||||
copy(filesCopy, changedFiles)
|
||||
|
||||
w.debounceTimer = time.AfterFunc(w.debounceTime, func() {
|
||||
var err error
|
||||
if w.callback != nil {
|
||||
err = w.callback(filesCopy)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
w.logError("Callback error: %v", err)
|
||||
}
|
||||
|
||||
w.debounceMu.Lock()
|
||||
w.debouncing = false
|
||||
w.debounceMu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Watcher) logDebug(format string, args ...any) {
|
||||
logger.Debugf("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (w *Watcher) logError(format string, args ...any) {
|
||||
logger.Errorf("[Watcher] [%s] %s", w.dir, fmt.Sprintf(format, args...))
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user