migrate to new LJTG API
This commit is contained in:
parent
cc6a7675d8
commit
e2b1b932ff
155
runner/crypto.go
155
runner/crypto.go
@ -78,16 +78,21 @@ func CleanupCrypto(state *luajit.State) {
|
|||||||
|
|
||||||
// cryptoHash generates hash digests using various algorithms
|
// cryptoHash generates hash digests using various algorithms
|
||||||
func cryptoHash(state *luajit.State) int {
|
func cryptoHash(state *luajit.State) int {
|
||||||
if !state.IsString(1) || !state.IsString(2) {
|
if err := state.CheckMinArgs(2); err != nil {
|
||||||
state.PushString("hash: expected (string data, string algorithm)")
|
return state.PushError("hash: %v", err)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data := state.ToString(1)
|
data, err := state.SafeToString(1)
|
||||||
algorithm := state.ToString(2)
|
if err != nil {
|
||||||
|
return state.PushError("hash: data must be string")
|
||||||
|
}
|
||||||
|
|
||||||
|
algorithm, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("hash: algorithm must be string")
|
||||||
|
}
|
||||||
|
|
||||||
var h hash.Hash
|
var h hash.Hash
|
||||||
|
|
||||||
switch algorithm {
|
switch algorithm {
|
||||||
case "md5":
|
case "md5":
|
||||||
h = md5.New()
|
h = md5.New()
|
||||||
@ -98,8 +103,7 @@ func cryptoHash(state *luajit.State) int {
|
|||||||
case "sha512":
|
case "sha512":
|
||||||
h = sha512.New()
|
h = sha512.New()
|
||||||
default:
|
default:
|
||||||
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
|
return state.PushError("unsupported algorithm: %s", algorithm)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
h.Write([]byte(data))
|
h.Write([]byte(data))
|
||||||
@ -107,8 +111,10 @@ func cryptoHash(state *luajit.State) int {
|
|||||||
|
|
||||||
// Output format
|
// Output format
|
||||||
outputFormat := "hex"
|
outputFormat := "hex"
|
||||||
if state.GetTop() >= 3 && state.IsString(3) {
|
if state.GetTop() >= 3 {
|
||||||
outputFormat = state.ToString(3)
|
if format, err := state.SafeToString(3); err == nil {
|
||||||
|
outputFormat = format
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch outputFormat {
|
switch outputFormat {
|
||||||
@ -125,17 +131,26 @@ func cryptoHash(state *luajit.State) int {
|
|||||||
|
|
||||||
// cryptoHmac generates HMAC using various hash algorithms
|
// cryptoHmac generates HMAC using various hash algorithms
|
||||||
func cryptoHmac(state *luajit.State) int {
|
func cryptoHmac(state *luajit.State) int {
|
||||||
if !state.IsString(1) || !state.IsString(2) || !state.IsString(3) {
|
if err := state.CheckMinArgs(3); err != nil {
|
||||||
state.PushString("hmac: expected (string data, string key, string algorithm)")
|
return state.PushError("hmac: %v", err)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data := state.ToString(1)
|
data, err := state.SafeToString(1)
|
||||||
key := state.ToString(2)
|
if err != nil {
|
||||||
algorithm := state.ToString(3)
|
return state.PushError("hmac: data must be string")
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("hmac: key must be string")
|
||||||
|
}
|
||||||
|
|
||||||
|
algorithm, err := state.SafeToString(3)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("hmac: algorithm must be string")
|
||||||
|
}
|
||||||
|
|
||||||
var h func() hash.Hash
|
var h func() hash.Hash
|
||||||
|
|
||||||
switch algorithm {
|
switch algorithm {
|
||||||
case "md5":
|
case "md5":
|
||||||
h = md5.New
|
h = md5.New
|
||||||
@ -146,8 +161,7 @@ func cryptoHmac(state *luajit.State) int {
|
|||||||
case "sha512":
|
case "sha512":
|
||||||
h = sha512.New
|
h = sha512.New
|
||||||
default:
|
default:
|
||||||
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
|
return state.PushError("unsupported algorithm: %s", algorithm)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mac := hmac.New(h, []byte(key))
|
mac := hmac.New(h, []byte(key))
|
||||||
@ -156,8 +170,10 @@ func cryptoHmac(state *luajit.State) int {
|
|||||||
|
|
||||||
// Output format
|
// Output format
|
||||||
outputFormat := "hex"
|
outputFormat := "hex"
|
||||||
if state.GetTop() >= 4 && state.IsString(4) {
|
if state.GetTop() >= 4 {
|
||||||
outputFormat = state.ToString(4)
|
if format, err := state.SafeToString(4); err == nil {
|
||||||
|
outputFormat = format
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch outputFormat {
|
switch outputFormat {
|
||||||
@ -175,10 +191,8 @@ func cryptoHmac(state *luajit.State) int {
|
|||||||
// cryptoUuid generates a random UUID v4
|
// cryptoUuid generates a random UUID v4
|
||||||
func cryptoUuid(state *luajit.State) int {
|
func cryptoUuid(state *luajit.State) int {
|
||||||
uuid := make([]byte, 16)
|
uuid := make([]byte, 16)
|
||||||
_, err := rand.Read(uuid)
|
if _, err := rand.Read(uuid); err != nil {
|
||||||
if err != nil {
|
return state.PushError("uuid: generation error: %v", err)
|
||||||
state.PushString(fmt.Sprintf("uuid: generation error: %v", err))
|
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set version (4) and variant (RFC 4122)
|
// Set version (4) and variant (RFC 4122)
|
||||||
@ -194,15 +208,17 @@ func cryptoUuid(state *luajit.State) int {
|
|||||||
|
|
||||||
// cryptoRandomBytes generates random bytes
|
// cryptoRandomBytes generates random bytes
|
||||||
func cryptoRandomBytes(state *luajit.State) int {
|
func cryptoRandomBytes(state *luajit.State) int {
|
||||||
if !state.IsNumber(1) {
|
if err := state.CheckMinArgs(1); err != nil {
|
||||||
state.PushString("random_bytes: expected (number length)")
|
return state.PushError("random_bytes: %v", err)
|
||||||
return 1
|
}
|
||||||
|
|
||||||
|
length, err := state.SafeToNumber(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("random_bytes: length must be number")
|
||||||
}
|
}
|
||||||
|
|
||||||
length := int(state.ToNumber(1))
|
|
||||||
if length <= 0 {
|
if length <= 0 {
|
||||||
state.PushString("random_bytes: length must be positive")
|
return state.PushError("random_bytes: length must be positive")
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if secure
|
// Check if secure
|
||||||
@ -211,13 +227,11 @@ func cryptoRandomBytes(state *luajit.State) int {
|
|||||||
secure = state.ToBoolean(2)
|
secure = state.ToBoolean(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
bytes := make([]byte, length)
|
bytes := make([]byte, int(length))
|
||||||
|
|
||||||
if secure {
|
if secure {
|
||||||
_, err := rand.Read(bytes)
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
if err != nil {
|
return state.PushError("random_bytes: error: %v", err)
|
||||||
state.PushString(fmt.Sprintf("random_bytes: error: %v", err))
|
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
stateRngsMu.Lock()
|
stateRngsMu.Lock()
|
||||||
@ -225,8 +239,7 @@ func cryptoRandomBytes(state *luajit.State) int {
|
|||||||
stateRngsMu.Unlock()
|
stateRngsMu.Unlock()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
state.PushString("random_bytes: RNG not initialized")
|
return state.PushError("random_bytes: RNG not initialized")
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range bytes {
|
for i := range bytes {
|
||||||
@ -236,8 +249,10 @@ func cryptoRandomBytes(state *luajit.State) int {
|
|||||||
|
|
||||||
// Output format
|
// Output format
|
||||||
outputFormat := "binary"
|
outputFormat := "binary"
|
||||||
if state.GetTop() >= 3 && state.IsString(3) {
|
if state.GetTop() >= 3 {
|
||||||
outputFormat = state.ToString(3)
|
if format, err := state.SafeToString(3); err == nil {
|
||||||
|
outputFormat = format
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch outputFormat {
|
switch outputFormat {
|
||||||
@ -254,17 +269,25 @@ func cryptoRandomBytes(state *luajit.State) int {
|
|||||||
|
|
||||||
// cryptoRandomInt generates a random integer in range [min, max]
|
// cryptoRandomInt generates a random integer in range [min, max]
|
||||||
func cryptoRandomInt(state *luajit.State) int {
|
func cryptoRandomInt(state *luajit.State) int {
|
||||||
if !state.IsNumber(1) || !state.IsNumber(2) {
|
if err := state.CheckMinArgs(2); err != nil {
|
||||||
state.PushString("random_int: expected (number min, number max)")
|
return state.PushError("random_int: %v", err)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
min := int64(state.ToNumber(1))
|
minVal, err := state.SafeToNumber(1)
|
||||||
max := int64(state.ToNumber(2))
|
if err != nil {
|
||||||
|
return state.PushError("random_int: min must be number")
|
||||||
|
}
|
||||||
|
|
||||||
|
maxVal, err := state.SafeToNumber(2)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("random_int: max must be number")
|
||||||
|
}
|
||||||
|
|
||||||
|
min := int64(minVal)
|
||||||
|
max := int64(maxVal)
|
||||||
|
|
||||||
if max <= min {
|
if max <= min {
|
||||||
state.PushString("random_int: max must be greater than min")
|
return state.PushError("random_int: max must be greater than min")
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if secure
|
// Check if secure
|
||||||
@ -274,15 +297,12 @@ func cryptoRandomInt(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
range_size := max - min + 1
|
range_size := max - min + 1
|
||||||
|
|
||||||
var result int64
|
var result int64
|
||||||
|
|
||||||
if secure {
|
if secure {
|
||||||
bytes := make([]byte, 8)
|
bytes := make([]byte, 8)
|
||||||
_, err := rand.Read(bytes)
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
if err != nil {
|
return state.PushError("random_int: error: %v", err)
|
||||||
state.PushString(fmt.Sprintf("random_int: error: %v", err))
|
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val := binary.BigEndian.Uint64(bytes)
|
val := binary.BigEndian.Uint64(bytes)
|
||||||
@ -293,8 +313,7 @@ func cryptoRandomInt(state *luajit.State) int {
|
|||||||
stateRngsMu.Unlock()
|
stateRngsMu.Unlock()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
state.PushString("random_int: RNG not initialized")
|
return state.PushError("random_int: RNG not initialized")
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = min + int64(stateRng.Uint64()%uint64(range_size))
|
result = min + int64(stateRng.Uint64()%uint64(range_size))
|
||||||
@ -315,10 +334,8 @@ func cryptoRandom(state *luajit.State) int {
|
|||||||
if numArgs == 0 {
|
if numArgs == 0 {
|
||||||
if secure {
|
if secure {
|
||||||
bytes := make([]byte, 8)
|
bytes := make([]byte, 8)
|
||||||
_, err := rand.Read(bytes)
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
if err != nil {
|
return state.PushError("random: error: %v", err)
|
||||||
state.PushString(fmt.Sprintf("random: error: %v", err))
|
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
val := binary.BigEndian.Uint64(bytes)
|
val := binary.BigEndian.Uint64(bytes)
|
||||||
state.PushNumber(float64(val) / float64(math.MaxUint64))
|
state.PushNumber(float64(val) / float64(math.MaxUint64))
|
||||||
@ -328,8 +345,7 @@ func cryptoRandom(state *luajit.State) int {
|
|||||||
stateRngsMu.Unlock()
|
stateRngsMu.Unlock()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
state.PushString("random: RNG not initialized")
|
return state.PushError("random: RNG not initialized")
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64))
|
state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64))
|
||||||
@ -341,8 +357,7 @@ func cryptoRandom(state *luajit.State) int {
|
|||||||
if numArgs == 1 && state.IsNumber(1) {
|
if numArgs == 1 && state.IsNumber(1) {
|
||||||
n := int64(state.ToNumber(1))
|
n := int64(state.ToNumber(1))
|
||||||
if n < 1 {
|
if n < 1 {
|
||||||
state.PushString("random: upper bound must be >= 1")
|
return state.PushError("random: upper bound must be >= 1")
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushNumber(1) // min
|
state.PushNumber(1) // min
|
||||||
@ -357,21 +372,23 @@ func cryptoRandom(state *luajit.State) int {
|
|||||||
return cryptoRandomInt(state)
|
return cryptoRandomInt(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushString("random: invalid arguments")
|
return state.PushError("random: invalid arguments")
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cryptoRandomSeed sets seed for non-secure RNG
|
// cryptoRandomSeed sets seed for non-secure RNG
|
||||||
func cryptoRandomSeed(state *luajit.State) int {
|
func cryptoRandomSeed(state *luajit.State) int {
|
||||||
if !state.IsNumber(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("randomseed: expected (number seed)")
|
return state.PushError("randomseed: %v", err)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
seed := uint64(state.ToNumber(1))
|
seed, err := state.SafeToNumber(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("randomseed: seed must be number")
|
||||||
|
}
|
||||||
|
|
||||||
|
seedVal := uint64(seed)
|
||||||
stateRngsMu.Lock()
|
stateRngsMu.Lock()
|
||||||
stateRngs[state] = mrand.NewPCG(seed, seed>>32)
|
stateRngs[state] = mrand.NewPCG(seedVal, seedVal>>32)
|
||||||
stateRngsMu.Unlock()
|
stateRngsMu.Unlock()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
@ -71,7 +71,7 @@ var (
|
|||||||
// precompileModule compiles a module's code to bytecode once
|
// precompileModule compiles a module's code to bytecode once
|
||||||
func precompileModule(m *ModuleInfo) {
|
func precompileModule(m *ModuleInfo) {
|
||||||
m.Once.Do(func() {
|
m.Once.Do(func() {
|
||||||
tempState := luajit.New()
|
tempState := luajit.New(true) // Explicitly open standard libraries
|
||||||
if tempState == nil {
|
if tempState == nil {
|
||||||
logger.Fatalf("Failed to create temp Lua state for %s module compilation", m.Name)
|
logger.Fatalf("Failed to create temp Lua state for %s module compilation", m.Name)
|
||||||
return
|
return
|
||||||
@ -105,18 +105,14 @@ func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error {
|
|||||||
logger.Debugf("Loading %s.lua from precompiled bytecode", m.Name)
|
logger.Debugf("Loading %s.lua from precompiled bytecode", m.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := state.LoadBytecode(*bytecode, m.Name+".lua"); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.DefinesGlobal {
|
if m.DefinesGlobal {
|
||||||
// Module defines its own globals, just run it
|
// Module defines its own globals, just run it
|
||||||
if err := state.RunBytecode(); err != nil {
|
if err := state.LoadAndRunBytecode(*bytecode, m.Name+".lua"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Module returns a table, capture and set as global
|
// Module returns a table, capture and set as global
|
||||||
if err := state.RunBytecodeWithResults(1); err != nil {
|
if err := state.LoadAndRunBytecodeWithResults(*bytecode, m.Name+".lua", 1); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
state.SetGlobal(m.Name)
|
state.SetGlobal(m.Name)
|
||||||
|
@ -221,12 +221,17 @@ func CleanupEnv() error {
|
|||||||
|
|
||||||
// envGet Lua function to get an environment variable
|
// envGet Lua function to get an environment variable
|
||||||
func envGet(state *luajit.State) int {
|
func envGet(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
|
state.PushNil()
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
key := state.ToString(1)
|
|
||||||
if value, exists := globalEnvManager.Get(key); exists {
|
if value, exists := globalEnvManager.Get(key); exists {
|
||||||
if err := state.PushValue(value); err != nil {
|
if err := state.PushValue(value); err != nil {
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
@ -239,13 +244,22 @@ func envGet(state *luajit.State) int {
|
|||||||
|
|
||||||
// envSet Lua function to set an environment variable
|
// envSet Lua function to set an environment variable
|
||||||
func envSet(state *luajit.State) int {
|
func envSet(state *luajit.State) int {
|
||||||
if !state.IsString(1) || !state.IsString(2) {
|
if err := state.CheckExactArgs(2); err != nil {
|
||||||
state.PushBoolean(false)
|
state.PushBoolean(false)
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
key := state.ToString(1)
|
key, err := state.SafeToString(1)
|
||||||
value := state.ToString(2)
|
if err != nil {
|
||||||
|
state.PushBoolean(false)
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
state.PushBoolean(false)
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
globalEnvManager.Set(key, value)
|
globalEnvManager.Set(key, value)
|
||||||
state.PushBoolean(true)
|
state.PushBoolean(true)
|
||||||
@ -255,11 +269,9 @@ func envSet(state *luajit.State) int {
|
|||||||
// envGetAll Lua function to get all environment variables
|
// envGetAll Lua function to get all environment variables
|
||||||
func envGetAll(state *luajit.State) int {
|
func envGetAll(state *luajit.State) int {
|
||||||
vars := globalEnvManager.GetAll()
|
vars := globalEnvManager.GetAll()
|
||||||
|
if err := state.PushValue(vars); err != nil {
|
||||||
if err := state.PushTable(vars); err != nil {
|
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
308
runner/fs.go
308
runner/fs.go
@ -112,23 +112,24 @@ func getCacheKey(fullPath string, modTime time.Time) string {
|
|||||||
|
|
||||||
// fsReadFile reads a file and returns its contents
|
// fsReadFile reads a file and returns its contents
|
||||||
func fsReadFile(state *luajit.State) int {
|
func fsReadFile(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("fs.read_file: path must be a string")
|
return state.PushError("fs.read_file: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.read_file: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
fullPath, err := ResolvePath(path)
|
fullPath, err := ResolvePath(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.read_file: " + err.Error())
|
return state.PushError("fs.read_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get file info for cache key and validation
|
// Get file info for cache key and validation
|
||||||
info, err := os.Stat(fullPath)
|
info, err := os.Stat(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.read_file: " + err.Error())
|
return state.PushError("fs.read_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create cache key with path and modification time
|
// Create cache key with path and modification time
|
||||||
@ -154,8 +155,7 @@ func fsReadFile(state *luajit.State) int {
|
|||||||
stats.misses++
|
stats.misses++
|
||||||
data, err := os.ReadFile(fullPath)
|
data, err := os.ReadFile(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.read_file: " + err.Error())
|
return state.PushError("fs.read_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compress and cache the data
|
// Compress and cache the data
|
||||||
@ -170,41 +170,33 @@ func fsReadFile(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsWriteFile writes data to a file
|
// fsWriteFile writes data to a file
|
||||||
func fsWriteFile(state *luajit.State) int {
|
func fsWriteFile(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(2); err != nil {
|
||||||
state.PushString("fs.write_file: path must be a string")
|
return state.PushError("fs.write_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
if !state.IsString(2) {
|
path, err := state.SafeToString(1)
|
||||||
state.PushString("fs.write_file: content must be a string")
|
if err != nil {
|
||||||
return -1
|
return state.PushError("fs.write_file: path must be string")
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.write_file: content must be string")
|
||||||
}
|
}
|
||||||
content := state.ToString(2)
|
|
||||||
|
|
||||||
fullPath, err := ResolvePath(path)
|
fullPath, err := ResolvePath(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.write_file: " + err.Error())
|
return state.PushError("fs.write_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the directory exists
|
// Ensure the directory exists
|
||||||
dir := filepath.Dir(fullPath)
|
dir := filepath.Dir(fullPath)
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
state.PushString("fs.write_file: failed to create directory: " + err.Error())
|
return state.PushError("fs.write_file: failed to create directory: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.WriteFile(fullPath, []byte(content), 0644)
|
if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil {
|
||||||
if err != nil {
|
return state.PushError("fs.write_file: %v", err)
|
||||||
state.PushString("fs.write_file: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Invalidate cache entries for this file path
|
|
||||||
if fileCache != nil {
|
|
||||||
// We can't easily iterate through cache keys, so we'll let the cache
|
|
||||||
// naturally expire old entries when the file is read again
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushBoolean(true)
|
state.PushBoolean(true)
|
||||||
@ -213,42 +205,39 @@ func fsWriteFile(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsAppendFile appends data to a file
|
// fsAppendFile appends data to a file
|
||||||
func fsAppendFile(state *luajit.State) int {
|
func fsAppendFile(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(2); err != nil {
|
||||||
state.PushString("fs.append_file: path must be a string")
|
return state.PushError("fs.append_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
if !state.IsString(2) {
|
path, err := state.SafeToString(1)
|
||||||
state.PushString("fs.append_file: content must be a string")
|
if err != nil {
|
||||||
return -1
|
return state.PushError("fs.append_file: path must be string")
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.append_file: content must be string")
|
||||||
}
|
}
|
||||||
content := state.ToString(2)
|
|
||||||
|
|
||||||
fullPath, err := ResolvePath(path)
|
fullPath, err := ResolvePath(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.append_file: " + err.Error())
|
return state.PushError("fs.append_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the directory exists
|
// Ensure the directory exists
|
||||||
dir := filepath.Dir(fullPath)
|
dir := filepath.Dir(fullPath)
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
state.PushString("fs.append_file: failed to create directory: " + err.Error())
|
return state.PushError("fs.append_file: failed to create directory: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.append_file: " + err.Error())
|
return state.PushError("fs.append_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
_, err = file.Write([]byte(content))
|
if _, err = file.Write([]byte(content)); err != nil {
|
||||||
if err != nil {
|
return state.PushError("fs.append_file: %v", err)
|
||||||
state.PushString("fs.append_file: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushBoolean(true)
|
state.PushBoolean(true)
|
||||||
@ -257,16 +246,18 @@ func fsAppendFile(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsExists checks if a file or directory exists
|
// fsExists checks if a file or directory exists
|
||||||
func fsExists(state *luajit.State) int {
|
func fsExists(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("fs.exists: path must be a string")
|
return state.PushError("fs.exists: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.exists: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
fullPath, err := ResolvePath(path)
|
fullPath, err := ResolvePath(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.exists: " + err.Error())
|
return state.PushError("fs.exists: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = os.Stat(fullPath)
|
_, err = os.Stat(fullPath)
|
||||||
@ -276,34 +267,32 @@ func fsExists(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsRemoveFile removes a file
|
// fsRemoveFile removes a file
|
||||||
func fsRemoveFile(state *luajit.State) int {
|
func fsRemoveFile(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("fs.remove_file: path must be a string")
|
return state.PushError("fs.remove_file: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.remove_file: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
fullPath, err := ResolvePath(path)
|
fullPath, err := ResolvePath(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.remove_file: " + err.Error())
|
return state.PushError("fs.remove_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if it's a directory
|
// Check if it's a directory
|
||||||
info, err := os.Stat(fullPath)
|
info, err := os.Stat(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.remove_file: " + err.Error())
|
return state.PushError("fs.remove_file: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.IsDir() {
|
if info.IsDir() {
|
||||||
state.PushString("fs.remove_file: cannot remove directory, use remove_dir instead")
|
return state.PushError("fs.remove_file: cannot remove directory, use remove_dir instead")
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.Remove(fullPath)
|
if err := os.Remove(fullPath); err != nil {
|
||||||
if err != nil {
|
return state.PushError("fs.remove_file: %v", err)
|
||||||
state.PushString("fs.remove_file: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushBoolean(true)
|
state.PushBoolean(true)
|
||||||
@ -312,67 +301,65 @@ func fsRemoveFile(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsGetInfo gets information about a file
|
// fsGetInfo gets information about a file
|
||||||
func fsGetInfo(state *luajit.State) int {
|
func fsGetInfo(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("fs.get_info: path must be a string")
|
return state.PushError("fs.get_info: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.get_info: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
fullPath, err := ResolvePath(path)
|
fullPath, err := ResolvePath(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.get_info: " + err.Error())
|
return state.PushError("fs.get_info: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info, err := os.Stat(fullPath)
|
info, err := os.Stat(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.get_info: " + err.Error())
|
return state.PushError("fs.get_info: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.NewTable()
|
fileInfo := map[string]any{
|
||||||
|
"name": info.Name(),
|
||||||
|
"size": info.Size(),
|
||||||
|
"mode": int(info.Mode()),
|
||||||
|
"mod_time": info.ModTime().Unix(),
|
||||||
|
"is_dir": info.IsDir(),
|
||||||
|
}
|
||||||
|
|
||||||
state.PushString(info.Name())
|
if err := state.PushValue(fileInfo); err != nil {
|
||||||
state.SetField(-2, "name")
|
return state.PushError("fs.get_info: %v", err)
|
||||||
|
}
|
||||||
state.PushNumber(float64(info.Size()))
|
|
||||||
state.SetField(-2, "size")
|
|
||||||
|
|
||||||
state.PushNumber(float64(info.Mode()))
|
|
||||||
state.SetField(-2, "mode")
|
|
||||||
|
|
||||||
state.PushNumber(float64(info.ModTime().Unix()))
|
|
||||||
state.SetField(-2, "mod_time")
|
|
||||||
|
|
||||||
state.PushBoolean(info.IsDir())
|
|
||||||
state.SetField(-2, "is_dir")
|
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// fsMakeDir creates a directory
|
// fsMakeDir creates a directory
|
||||||
func fsMakeDir(state *luajit.State) int {
|
func fsMakeDir(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckMinArgs(1); err != nil {
|
||||||
state.PushString("fs.make_dir: path must be a string")
|
return state.PushError("fs.make_dir: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.make_dir: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
perm := os.FileMode(0755)
|
perm := os.FileMode(0755)
|
||||||
if state.GetTop() >= 2 && state.IsNumber(2) {
|
if state.GetTop() >= 2 {
|
||||||
perm = os.FileMode(state.ToNumber(2))
|
if permVal, err := state.SafeToNumber(2); err == nil {
|
||||||
|
perm = os.FileMode(permVal)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fullPath, err := ResolvePath(path)
|
fullPath, err := ResolvePath(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.make_dir: " + err.Error())
|
return state.PushError("fs.make_dir: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.MkdirAll(fullPath, perm)
|
if err := os.MkdirAll(fullPath, perm); err != nil {
|
||||||
if err != nil {
|
return state.PushError("fs.make_dir: %v", err)
|
||||||
state.PushString("fs.make_dir: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushBoolean(true)
|
state.PushBoolean(true)
|
||||||
@ -381,41 +368,42 @@ func fsMakeDir(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsListDir lists the contents of a directory
|
// fsListDir lists the contents of a directory
|
||||||
func fsListDir(state *luajit.State) int {
|
func fsListDir(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("fs.list_dir: path must be a string")
|
return state.PushError("fs.list_dir: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.list_dir: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
fullPath, err := ResolvePath(path)
|
fullPath, err := ResolvePath(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.list_dir: " + err.Error())
|
return state.PushError("fs.list_dir: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info, err := os.Stat(fullPath)
|
info, err := os.Stat(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.list_dir: " + err.Error())
|
return state.PushError("fs.list_dir: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !info.IsDir() {
|
if !info.IsDir() {
|
||||||
state.PushString("fs.list_dir: not a directory")
|
return state.PushError("fs.list_dir: not a directory")
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
files, err := os.ReadDir(fullPath)
|
files, err := os.ReadDir(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.list_dir: " + err.Error())
|
return state.PushError("fs.list_dir: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.NewTable()
|
// Create array of filenames
|
||||||
|
filenames := make([]string, len(files))
|
||||||
for i, file := range files {
|
for i, file := range files {
|
||||||
state.PushNumber(float64(i + 1))
|
filenames[i] = file.Name()
|
||||||
state.PushString(file.Name())
|
}
|
||||||
state.SetTable(-3)
|
|
||||||
|
if err := state.PushValue(filenames); err != nil {
|
||||||
|
return state.PushError("fs.list_dir: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
@ -423,11 +411,14 @@ func fsListDir(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsRemoveDir removes a directory
|
// fsRemoveDir removes a directory
|
||||||
func fsRemoveDir(state *luajit.State) int {
|
func fsRemoveDir(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckMinArgs(1); err != nil {
|
||||||
state.PushString("fs.remove_dir: path must be a string")
|
return state.PushError("fs.remove_dir: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.remove_dir: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
recursive := false
|
recursive := false
|
||||||
if state.GetTop() >= 2 {
|
if state.GetTop() >= 2 {
|
||||||
@ -436,19 +427,16 @@ func fsRemoveDir(state *luajit.State) int {
|
|||||||
|
|
||||||
fullPath, err := ResolvePath(path)
|
fullPath, err := ResolvePath(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.remove_dir: " + err.Error())
|
return state.PushError("fs.remove_dir: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info, err := os.Stat(fullPath)
|
info, err := os.Stat(fullPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.remove_dir: " + err.Error())
|
return state.PushError("fs.remove_dir: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !info.IsDir() {
|
if !info.IsDir() {
|
||||||
state.PushString("fs.remove_dir: not a directory")
|
return state.PushError("fs.remove_dir: not a directory")
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if recursive {
|
if recursive {
|
||||||
@ -458,8 +446,7 @@ func fsRemoveDir(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("fs.remove_dir: " + err.Error())
|
return state.PushError("fs.remove_dir: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushBoolean(true)
|
state.PushBoolean(true)
|
||||||
@ -468,19 +455,17 @@ func fsRemoveDir(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsJoinPaths joins path components
|
// fsJoinPaths joins path components
|
||||||
func fsJoinPaths(state *luajit.State) int {
|
func fsJoinPaths(state *luajit.State) int {
|
||||||
nargs := state.GetTop()
|
if err := state.CheckMinArgs(1); err != nil {
|
||||||
if nargs < 1 {
|
return state.PushError("fs.join_paths: %v", err)
|
||||||
state.PushString("fs.join_paths: at least one path component required")
|
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
components := make([]string, nargs)
|
components := make([]string, state.GetTop())
|
||||||
for i := 1; i <= nargs; i++ {
|
for i := 1; i <= state.GetTop(); i++ {
|
||||||
if !state.IsString(i) {
|
comp, err := state.SafeToString(i)
|
||||||
state.PushString("fs.join_paths: all arguments must be strings")
|
if err != nil {
|
||||||
return -1
|
return state.PushError("fs.join_paths: all arguments must be strings")
|
||||||
}
|
}
|
||||||
components[i-1] = state.ToString(i)
|
components[i-1] = comp
|
||||||
}
|
}
|
||||||
|
|
||||||
result := filepath.Join(components...)
|
result := filepath.Join(components...)
|
||||||
@ -492,11 +477,14 @@ func fsJoinPaths(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsDirName returns the directory portion of a path
|
// fsDirName returns the directory portion of a path
|
||||||
func fsDirName(state *luajit.State) int {
|
func fsDirName(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("fs.dir_name: path must be a string")
|
return state.PushError("fs.dir_name: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.dir_name: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
dir := filepath.Dir(path)
|
dir := filepath.Dir(path)
|
||||||
dir = strings.ReplaceAll(dir, "\\", "/")
|
dir = strings.ReplaceAll(dir, "\\", "/")
|
||||||
@ -507,28 +495,32 @@ func fsDirName(state *luajit.State) int {
|
|||||||
|
|
||||||
// fsBaseName returns the file name portion of a path
|
// fsBaseName returns the file name portion of a path
|
||||||
func fsBaseName(state *luajit.State) int {
|
func fsBaseName(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("fs.base_name: path must be a string")
|
return state.PushError("fs.base_name: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.base_name: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
base := filepath.Base(path)
|
base := filepath.Base(path)
|
||||||
|
|
||||||
state.PushString(base)
|
state.PushString(base)
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// fsExtension returns the file extension
|
// fsExtension returns the file extension
|
||||||
func fsExtension(state *luajit.State) int {
|
func fsExtension(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("fs.extension: path must be a string")
|
return state.PushError("fs.extension: %v", err)
|
||||||
return -1
|
}
|
||||||
|
|
||||||
|
path, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("fs.extension: path must be string")
|
||||||
}
|
}
|
||||||
path := state.ToString(1)
|
|
||||||
|
|
||||||
ext := filepath.Ext(path)
|
ext := filepath.Ext(path)
|
||||||
|
|
||||||
state.PushString(ext)
|
state.PushString(ext)
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
155
runner/http.go
155
runner/http.go
@ -94,25 +94,26 @@ func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
|
|||||||
|
|
||||||
// httpRequest makes an HTTP request and returns the result to Lua
|
// httpRequest makes an HTTP request and returns the result to Lua
|
||||||
func httpRequest(state *luajit.State) int {
|
func httpRequest(state *luajit.State) int {
|
||||||
// Get method (required)
|
if err := state.CheckMinArgs(2); err != nil {
|
||||||
if !state.IsString(1) {
|
return state.PushError("http.client.request: %v", err)
|
||||||
state.PushString("http.client.request: method must be a string")
|
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
method := strings.ToUpper(state.ToString(1))
|
|
||||||
|
|
||||||
// Get URL (required)
|
// Get method and URL
|
||||||
if !state.IsString(2) {
|
method, err := state.SafeToString(1)
|
||||||
state.PushString("http.client.request: url must be a string")
|
if err != nil {
|
||||||
return -1
|
return state.PushError("http.client.request: method must be string")
|
||||||
|
}
|
||||||
|
method = strings.ToUpper(method)
|
||||||
|
|
||||||
|
urlStr, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("http.client.request: url must be string")
|
||||||
}
|
}
|
||||||
urlStr := state.ToString(2)
|
|
||||||
|
|
||||||
// Parse URL to check if it's valid
|
// Parse URL to check if it's valid
|
||||||
parsedURL, err := url.Parse(urlStr)
|
parsedURL, err := url.Parse(urlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("Invalid URL: " + err.Error())
|
return state.PushError("Invalid URL: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get client configuration
|
// Get client configuration
|
||||||
@ -120,8 +121,7 @@ func httpRequest(state *luajit.State) int {
|
|||||||
|
|
||||||
// Check if remote connections are allowed
|
// Check if remote connections are allowed
|
||||||
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
|
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
|
||||||
state.PushString("Remote connections are not allowed")
|
return state.PushError("Remote connections are not allowed")
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use bytebufferpool for request and response
|
// Use bytebufferpool for request and response
|
||||||
@ -139,13 +139,13 @@ func httpRequest(state *luajit.State) int {
|
|||||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||||
if state.IsString(3) {
|
if state.IsString(3) {
|
||||||
// String body
|
// String body
|
||||||
req.SetBodyString(state.ToString(3))
|
bodyStr, _ := state.SafeToString(3)
|
||||||
|
req.SetBodyString(bodyStr)
|
||||||
} else if state.IsTable(3) {
|
} else if state.IsTable(3) {
|
||||||
// Table body - convert to JSON
|
// Table body - convert to JSON
|
||||||
luaTable, err := state.ToTable(3)
|
luaTable, err := state.SafeToTable(3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("Failed to parse body table: " + err.Error())
|
return state.PushError("Failed to parse body table: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use bytebufferpool for JSON serialization
|
// Use bytebufferpool for JSON serialization
|
||||||
@ -153,42 +153,37 @@ func httpRequest(state *luajit.State) int {
|
|||||||
defer bytebufferpool.Put(buf)
|
defer bytebufferpool.Put(buf)
|
||||||
|
|
||||||
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
|
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
|
||||||
state.PushString("Failed to convert body to JSON: " + err.Error())
|
return state.PushError("Failed to convert body to JSON: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
req.SetBody(buf.Bytes())
|
req.SetBody(buf.Bytes())
|
||||||
req.Header.SetContentType("application/json")
|
req.Header.SetContentType("application/json")
|
||||||
} else {
|
} else {
|
||||||
state.PushString("Body must be a string or table")
|
return state.PushError("Body must be a string or table")
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process options (headers, timeout, etc.)
|
// Process options (headers, timeout, etc.)
|
||||||
timeout := config.DefaultTimeout
|
timeout := config.DefaultTimeout
|
||||||
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
|
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
|
||||||
// Process headers
|
// Process headers using ForEachTableKV
|
||||||
state.GetField(4, "headers")
|
if headers, ok := state.GetFieldTable(4, "headers"); ok {
|
||||||
if state.IsTable(-1) {
|
if headerMap, ok := headers.(map[string]string); ok {
|
||||||
// Iterate through headers
|
for name, value := range headerMap {
|
||||||
state.PushNil() // Start iteration
|
req.Header.Set(name, value)
|
||||||
for state.Next(-2) {
|
}
|
||||||
// Stack now has key at -2 and value at -1
|
} else if headerMapAny, ok := headers.(map[string]any); ok {
|
||||||
if state.IsString(-2) && state.IsString(-1) {
|
for name, value := range headerMapAny {
|
||||||
headerName := state.ToString(-2)
|
if valueStr, ok := value.(string); ok {
|
||||||
headerValue := state.ToString(-1)
|
req.Header.Set(name, valueStr)
|
||||||
req.Header.Set(headerName, headerValue)
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1) // Pop value, leave key for next iteration
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1) // Pop headers table
|
|
||||||
|
|
||||||
// Get timeout
|
// Get timeout
|
||||||
state.GetField(4, "timeout")
|
if timeoutVal := state.GetFieldNumber(4, "timeout", 0); timeoutVal > 0 {
|
||||||
if state.IsNumber(-1) {
|
requestTimeout := time.Duration(timeoutVal) * time.Second
|
||||||
requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second
|
|
||||||
|
|
||||||
// Apply max timeout if configured
|
// Apply max timeout if configured
|
||||||
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
|
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
|
||||||
@ -197,38 +192,34 @@ func httpRequest(state *luajit.State) int {
|
|||||||
timeout = requestTimeout
|
timeout = requestTimeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1) // Pop timeout
|
|
||||||
|
|
||||||
// Process query parameters
|
// Process query parameters
|
||||||
state.GetField(4, "query")
|
if query, ok := state.GetFieldTable(4, "query"); ok {
|
||||||
if state.IsTable(-1) {
|
|
||||||
// Create URL args
|
|
||||||
args := req.URI().QueryArgs()
|
args := req.URI().QueryArgs()
|
||||||
|
|
||||||
// Iterate through query params
|
if queryMap, ok := query.(map[string]string); ok {
|
||||||
state.PushNil() // Start iteration
|
for name, value := range queryMap {
|
||||||
for state.Next(-2) {
|
args.Add(name, value)
|
||||||
if state.IsString(-2) {
|
}
|
||||||
paramName := state.ToString(-2)
|
} else if queryMapAny, ok := query.(map[string]any); ok {
|
||||||
|
for name, value := range queryMapAny {
|
||||||
// Handle different value types
|
switch v := value.(type) {
|
||||||
if state.IsString(-1) {
|
case string:
|
||||||
args.Add(paramName, state.ToString(-1))
|
args.Add(name, v)
|
||||||
} else if state.IsNumber(-1) {
|
case int:
|
||||||
args.Add(paramName, strings.TrimRight(strings.TrimRight(
|
args.Add(name, fmt.Sprintf("%d", v))
|
||||||
state.ToString(-1), "0"), "."))
|
case float64:
|
||||||
} else if state.IsBoolean(-1) {
|
args.Add(name, strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", v), "0"), "."))
|
||||||
if state.ToBoolean(-1) {
|
case bool:
|
||||||
args.Add(paramName, "true")
|
if v {
|
||||||
|
args.Add(name, "true")
|
||||||
} else {
|
} else {
|
||||||
args.Add(paramName, "false")
|
args.Add(name, "false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1) // Pop value, leave key for next iteration
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1) // Pop query table
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create context with timeout
|
// Create context with timeout
|
||||||
@ -242,26 +233,18 @@ func httpRequest(state *luajit.State) int {
|
|||||||
if errors.Is(err, fasthttp.ErrTimeout) {
|
if errors.Is(err, fasthttp.ErrTimeout) {
|
||||||
errStr = "Request timed out after " + timeout.String()
|
errStr = "Request timed out after " + timeout.String()
|
||||||
}
|
}
|
||||||
state.PushString(errStr)
|
return state.PushError("%s", errStr)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create response table
|
// Create response using TableBuilder
|
||||||
state.NewTable()
|
builder := state.NewTableBuilder()
|
||||||
|
|
||||||
// Set status code
|
// Set status code and text
|
||||||
state.PushNumber(float64(resp.StatusCode()))
|
builder.SetNumber("status", float64(resp.StatusCode()))
|
||||||
state.SetField(-2, "status")
|
builder.SetString("status_text", fasthttp.StatusMessage(resp.StatusCode()))
|
||||||
|
|
||||||
// Set status text
|
|
||||||
statusText := fasthttp.StatusMessage(resp.StatusCode())
|
|
||||||
state.PushString(statusText)
|
|
||||||
state.SetField(-2, "status_text")
|
|
||||||
|
|
||||||
// Set body
|
// Set body
|
||||||
var respBody []byte
|
var respBody []byte
|
||||||
|
|
||||||
// Apply size limits to response
|
|
||||||
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
|
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
|
||||||
// Make a limited copy
|
// Make a limited copy
|
||||||
respBody = make([]byte, config.MaxResponseSize)
|
respBody = make([]byte, config.MaxResponseSize)
|
||||||
@ -270,32 +253,28 @@ func httpRequest(state *luajit.State) int {
|
|||||||
respBody = resp.Body()
|
respBody = resp.Body()
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushString(string(respBody))
|
builder.SetString("body", string(respBody))
|
||||||
state.SetField(-2, "body")
|
|
||||||
|
|
||||||
// Parse body as JSON if content type is application/json
|
// Parse body as JSON if content type is application/json
|
||||||
contentType := string(resp.Header.ContentType())
|
contentType := string(resp.Header.ContentType())
|
||||||
if strings.Contains(contentType, "application/json") {
|
if strings.Contains(contentType, "application/json") {
|
||||||
var jsonData any
|
var jsonData any
|
||||||
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
if err := json.Unmarshal(respBody, &jsonData); err == nil {
|
||||||
if err := state.PushValue(jsonData); err == nil {
|
builder.SetTable("json", jsonData)
|
||||||
state.SetField(-2, "json")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set headers
|
// Set headers
|
||||||
state.NewTable()
|
headers := make(map[string]string)
|
||||||
resp.Header.VisitAll(func(key, value []byte) {
|
resp.Header.VisitAll(func(key, value []byte) {
|
||||||
state.PushString(string(value))
|
headers[string(key)] = string(value)
|
||||||
state.SetField(-2, string(key))
|
|
||||||
})
|
})
|
||||||
state.SetField(-2, "headers")
|
builder.SetTable("headers", headers)
|
||||||
|
|
||||||
// Create ok field (true if status code is 2xx)
|
// Create ok field (true if status code is 2xx)
|
||||||
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
|
builder.SetBool("ok", resp.StatusCode() >= 200 && resp.StatusCode() < 300)
|
||||||
state.SetField(-2, "ok")
|
|
||||||
|
|
||||||
|
builder.Build()
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -303,8 +282,10 @@ func httpRequest(state *luajit.State) int {
|
|||||||
func generateToken(state *luajit.State) int {
|
func generateToken(state *luajit.State) int {
|
||||||
// Get the length from the Lua arguments (default to 32)
|
// Get the length from the Lua arguments (default to 32)
|
||||||
length := 32
|
length := 32
|
||||||
if state.GetTop() >= 1 && state.IsNumber(1) {
|
if state.GetTop() >= 1 {
|
||||||
length = int(state.ToNumber(1))
|
if lengthVal, err := state.SafeToNumber(1); err == nil {
|
||||||
|
length = int(lengthVal)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce minimum length for security
|
// Enforce minimum length for security
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
package runner
|
package runner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
|
|
||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
"github.com/alexedwards/argon2id"
|
"github.com/alexedwards/argon2id"
|
||||||
)
|
)
|
||||||
@ -20,12 +18,14 @@ func RegisterPasswordFunctions(state *luajit.State) error {
|
|||||||
|
|
||||||
// passwordHash implements the Argon2id password hashing using alexedwards/argon2id
|
// passwordHash implements the Argon2id password hashing using alexedwards/argon2id
|
||||||
func passwordHash(state *luajit.State) int {
|
func passwordHash(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckMinArgs(1); err != nil {
|
||||||
state.PushString("password_hash error: expected string password")
|
return state.PushError("password_hash: %v", err)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
password := state.ToString(1)
|
password, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("password_hash: password must be string")
|
||||||
|
}
|
||||||
|
|
||||||
params := &argon2id.Params{
|
params := &argon2id.Params{
|
||||||
Memory: 128 * 1024,
|
Memory: 128 * 1024,
|
||||||
@ -35,42 +35,32 @@ func passwordHash(state *luajit.State) int {
|
|||||||
KeyLength: 32,
|
KeyLength: 32,
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.IsTable(2) {
|
if state.GetTop() >= 2 && state.IsTable(2) {
|
||||||
state.GetField(2, "memory")
|
// Use new field getters with validation
|
||||||
if state.IsNumber(-1) {
|
if memory := state.GetFieldNumber(2, "memory", 0); memory > 0 {
|
||||||
params.Memory = max(uint32(state.ToNumber(-1)), 8*1024)
|
params.Memory = max(uint32(memory), 8*1024)
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
state.GetField(2, "iterations")
|
if iterations := state.GetFieldNumber(2, "iterations", 0); iterations > 0 {
|
||||||
if state.IsNumber(-1) {
|
params.Iterations = max(uint32(iterations), 1)
|
||||||
params.Iterations = max(uint32(state.ToNumber(-1)), 1)
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
state.GetField(2, "parallelism")
|
if parallelism := state.GetFieldNumber(2, "parallelism", 0); parallelism > 0 {
|
||||||
if state.IsNumber(-1) {
|
params.Parallelism = max(uint8(parallelism), 1)
|
||||||
params.Parallelism = max(uint8(state.ToNumber(-1)), 1)
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
state.GetField(2, "salt_length")
|
if saltLength := state.GetFieldNumber(2, "salt_length", 0); saltLength > 0 {
|
||||||
if state.IsNumber(-1) {
|
params.SaltLength = max(uint32(saltLength), 8)
|
||||||
params.SaltLength = max(uint32(state.ToNumber(-1)), 8)
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
state.GetField(2, "key_length")
|
if keyLength := state.GetFieldNumber(2, "key_length", 0); keyLength > 0 {
|
||||||
if state.IsNumber(-1) {
|
params.KeyLength = max(uint32(keyLength), 16)
|
||||||
params.KeyLength = max(uint32(state.ToNumber(-1)), 16)
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hash, err := argon2id.CreateHash(password, params)
|
hash, err := argon2id.CreateHash(password, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("password_hash error: %v", err))
|
return state.PushError("password_hash: %v", err)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushString(hash)
|
state.PushString(hash)
|
||||||
@ -79,13 +69,22 @@ func passwordHash(state *luajit.State) int {
|
|||||||
|
|
||||||
// passwordVerify verifies a password against a hash
|
// passwordVerify verifies a password against a hash
|
||||||
func passwordVerify(state *luajit.State) int {
|
func passwordVerify(state *luajit.State) int {
|
||||||
if !state.IsString(1) || !state.IsString(2) {
|
if err := state.CheckExactArgs(2); err != nil {
|
||||||
state.PushBoolean(false)
|
state.PushBoolean(false)
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
password := state.ToString(1)
|
password, err := state.SafeToString(1)
|
||||||
hash := state.ToString(2)
|
if err != nil {
|
||||||
|
state.PushBoolean(false)
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
hash, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
state.PushBoolean(false)
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
match, err := argon2id.ComparePasswordAndHash(password, hash)
|
match, err := argon2id.ComparePasswordAndHash(password, hash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -156,7 +156,7 @@ func (r *Runner) createState(index int) (*State, error) {
|
|||||||
logger.Debugf("Creating Lua state %d", index)
|
logger.Debugf("Creating Lua state %d", index)
|
||||||
}
|
}
|
||||||
|
|
||||||
L := luajit.New()
|
L := luajit.New(true) // Explicitly open standard libraries
|
||||||
if L == nil {
|
if L == nil {
|
||||||
return nil, errors.New("failed to create Lua state")
|
return nil, errors.New("failed to create Lua state")
|
||||||
}
|
}
|
||||||
|
@ -130,42 +130,30 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
|
|||||||
|
|
||||||
// Execute runs a Lua script in the sandbox with the given context
|
// Execute runs a Lua script in the sandbox with the given context
|
||||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
|
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
|
||||||
state.GetGlobal("__execute_script")
|
// Use CallGlobal for cleaner function calling
|
||||||
if !state.IsFunction(-1) {
|
results, err := state.CallGlobal("__execute_script",
|
||||||
state.Pop(1)
|
func() any {
|
||||||
return nil, ErrSandboxNotInitialized
|
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||||
}
|
return nil
|
||||||
|
}
|
||||||
|
return "loaded_script"
|
||||||
|
}(),
|
||||||
|
ctx.Values)
|
||||||
|
|
||||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
if err != nil {
|
||||||
state.Pop(1) // Pop the __execute_script function
|
|
||||||
return nil, fmt.Errorf("failed to load script: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Push context values
|
|
||||||
if err := state.PushTable(ctx.Values); err != nil {
|
|
||||||
state.Pop(2) // Pop bytecode and __execute_script
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute with 2 args, 1 result
|
|
||||||
if err := state.Call(2, 1); err != nil {
|
|
||||||
return nil, fmt.Errorf("script execution failed: %w", err)
|
return nil, fmt.Errorf("script execution failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := state.ToValue(-1)
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
response := NewResponse()
|
response := NewResponse()
|
||||||
if err == nil {
|
if len(results) > 0 {
|
||||||
response.Body = body
|
response.Body = results[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
extractHTTPResponseData(state, response)
|
extractHTTPResponseData(state, response)
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractResponseData pulls response info from the Lua state
|
// extractResponseData pulls response info from the Lua state using new API
|
||||||
func extractHTTPResponseData(state *luajit.State, response *Response) {
|
func extractHTTPResponseData(state *luajit.State, response *Response) {
|
||||||
state.GetGlobal("__http_response")
|
state.GetGlobal("__http_response")
|
||||||
if !state.IsTable(-1) {
|
if !state.IsTable(-1) {
|
||||||
@ -173,162 +161,113 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract status
|
// Use new field getters with defaults
|
||||||
state.GetField(-1, "status")
|
response.Status = int(state.GetFieldNumber(-1, "status", 200))
|
||||||
if state.IsNumber(-1) {
|
|
||||||
response.Status = int(state.ToNumber(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Extract headers
|
// Extract headers using ForEachTableKV
|
||||||
state.GetField(-1, "headers")
|
if headerTable, ok := state.GetFieldTable(-1, "headers"); ok {
|
||||||
if state.IsTable(-1) {
|
if headers, ok := headerTable.(map[string]any); ok {
|
||||||
state.PushNil() // Start iteration
|
for k, v := range headers {
|
||||||
for state.Next(-2) {
|
if str, ok := v.(string); ok {
|
||||||
if state.IsString(-2) && state.IsString(-1) {
|
response.Headers[k] = str
|
||||||
key := state.ToString(-2)
|
}
|
||||||
value := state.ToString(-1)
|
|
||||||
response.Headers[key] = value
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Extract cookies
|
// Extract cookies using ForEachArray
|
||||||
state.GetField(-1, "cookies")
|
state.GetField(-1, "cookies")
|
||||||
if state.IsTable(-1) {
|
if state.IsTable(-1) {
|
||||||
length := state.GetTableLength(-1)
|
state.ForEachArray(-1, func(i int, s *luajit.State) bool {
|
||||||
for i := 1; i <= length; i++ {
|
if s.IsTable(-1) {
|
||||||
state.PushNumber(float64(i))
|
extractCookie(s, response)
|
||||||
state.GetTable(-2)
|
|
||||||
|
|
||||||
if state.IsTable(-1) {
|
|
||||||
extractCookie(state, response)
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
return true
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
// Extract metadata
|
// Extract metadata
|
||||||
state.GetField(-1, "metadata")
|
if metadata, ok := state.GetFieldTable(-1, "metadata"); ok {
|
||||||
if state.IsTable(-1) {
|
if metaMap, ok := metadata.(map[string]any); ok {
|
||||||
table, err := state.ToTable(-1)
|
maps.Copy(response.Metadata, metaMap)
|
||||||
if err == nil {
|
|
||||||
maps.Copy(response.Metadata, table)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Extract session data
|
// Extract session data
|
||||||
state.GetField(-1, "session")
|
if session, ok := state.GetFieldTable(-1, "session"); ok {
|
||||||
if state.IsTable(-1) {
|
if sessMap, ok := session.(map[string]any); ok {
|
||||||
table, err := state.ToTable(-1)
|
maps.Copy(response.SessionData, sessMap)
|
||||||
if err == nil {
|
|
||||||
maps.Copy(response.SessionData, table)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
state.Pop(1) // Pop __http_response
|
state.Pop(1) // Pop __http_response
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractCookie pulls cookie data from the current table on the stack
|
// extractCookie pulls cookie data from the current table on the stack using new API
|
||||||
func extractCookie(state *luajit.State, response *Response) {
|
func extractCookie(state *luajit.State, response *Response) {
|
||||||
cookie := fasthttp.AcquireCookie()
|
cookie := fasthttp.AcquireCookie()
|
||||||
|
|
||||||
// Get name (required)
|
// Use new field getters with defaults
|
||||||
state.GetField(-1, "name")
|
name := state.GetFieldString(-1, "name", "")
|
||||||
if !state.IsString(-1) {
|
if name == "" {
|
||||||
state.Pop(1)
|
|
||||||
fasthttp.ReleaseCookie(cookie)
|
fasthttp.ReleaseCookie(cookie)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cookie.SetKey(state.ToString(-1))
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get value
|
cookie.SetKey(name)
|
||||||
state.GetField(-1, "value")
|
cookie.SetValue(state.GetFieldString(-1, "value", ""))
|
||||||
if state.IsString(-1) {
|
cookie.SetPath(state.GetFieldString(-1, "path", "/"))
|
||||||
cookie.SetValue(state.ToString(-1))
|
cookie.SetDomain(state.GetFieldString(-1, "domain", ""))
|
||||||
}
|
cookie.SetHTTPOnly(state.GetFieldBool(-1, "http_only", false))
|
||||||
state.Pop(1)
|
cookie.SetSecure(state.GetFieldBool(-1, "secure", false))
|
||||||
|
cookie.SetMaxAge(int(state.GetFieldNumber(-1, "max_age", 0)))
|
||||||
// Get path
|
|
||||||
state.GetField(-1, "path")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetPath(state.ToString(-1))
|
|
||||||
} else {
|
|
||||||
cookie.SetPath("/") // Default
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get domain
|
|
||||||
state.GetField(-1, "domain")
|
|
||||||
if state.IsString(-1) {
|
|
||||||
cookie.SetDomain(state.ToString(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
// Get other parameters
|
|
||||||
state.GetField(-1, "http_only")
|
|
||||||
if state.IsBoolean(-1) {
|
|
||||||
cookie.SetHTTPOnly(state.ToBoolean(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
state.GetField(-1, "secure")
|
|
||||||
if state.IsBoolean(-1) {
|
|
||||||
cookie.SetSecure(state.ToBoolean(-1))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
state.GetField(-1, "max_age")
|
|
||||||
if state.IsNumber(-1) {
|
|
||||||
cookie.SetMaxAge(int(state.ToNumber(-1)))
|
|
||||||
}
|
|
||||||
state.Pop(1)
|
|
||||||
|
|
||||||
response.Cookies = append(response.Cookies, cookie)
|
response.Cookies = append(response.Cookies, cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
// jsonMarshal converts a Lua value to a JSON string
|
// jsonMarshal converts a Lua value to a JSON string with validation
|
||||||
func jsonMarshal(state *luajit.State) int {
|
func jsonMarshal(state *luajit.State) int {
|
||||||
value, err := state.ToValue(1)
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
|
return state.PushError("json marshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := state.SafeToTable(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("json marshal error: %v", err))
|
// Try as generic value if not a table
|
||||||
return 1
|
value, err = state.ToValue(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("json marshal error: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bytes, err := json.Marshal(value)
|
bytes, err := json.Marshal(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("json marshal error: %v", err))
|
return state.PushError("json marshal error: %v", err)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PushString(string(bytes))
|
state.PushString(string(bytes))
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// jsonUnmarshal converts a JSON string to a Lua value
|
// jsonUnmarshal converts a JSON string to a Lua value with validation
|
||||||
func jsonUnmarshal(state *luajit.State) int {
|
func jsonUnmarshal(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
state.PushString("json unmarshal error: expected string")
|
return state.PushError("json unmarshal: %v", err)
|
||||||
return 1
|
}
|
||||||
|
|
||||||
|
jsonStr, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("json unmarshal: expected string, got %s", state.GetType(1))
|
||||||
}
|
}
|
||||||
jsonStr := state.ToString(1)
|
|
||||||
|
|
||||||
var value any
|
var value any
|
||||||
err := json.Unmarshal([]byte(jsonStr), &value)
|
if err := json.Unmarshal([]byte(jsonStr), &value); err != nil {
|
||||||
if err != nil {
|
return state.PushError("json unmarshal error: %v", err)
|
||||||
state.PushString(fmt.Sprintf("json unmarshal error: %v", err))
|
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := state.PushValue(value); err != nil {
|
if err := state.PushValue(value); err != nil {
|
||||||
state.PushString(fmt.Sprintf("json unmarshal error: %v", err))
|
return state.PushError("json unmarshal error: %v", err)
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
118
runner/sqlite.go
118
runner/sqlite.go
@ -111,20 +111,24 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
|||||||
|
|
||||||
// sqlQuery executes a SQL query and returns results
|
// sqlQuery executes a SQL query and returns results
|
||||||
func sqlQuery(state *luajit.State) int {
|
func sqlQuery(state *luajit.State) int {
|
||||||
// Get required parameters
|
if err := state.CheckMinArgs(2); err != nil {
|
||||||
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
return state.PushError("sqlite.query: %v", err)
|
||||||
state.PushString("sqlite.query: requires database name and query")
|
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dbName := state.ToString(1)
|
dbName, err := state.SafeToString(1)
|
||||||
query := state.ToString(2)
|
if err != nil {
|
||||||
|
return state.PushError("sqlite.query: database name must be string")
|
||||||
|
}
|
||||||
|
|
||||||
|
query, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("sqlite.query: query must be string")
|
||||||
|
}
|
||||||
|
|
||||||
// Get pool
|
// Get pool
|
||||||
pool, err := getPool(dbName)
|
pool, err := getPool(dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
return state.PushError("sqlite.query: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection with timeout
|
// Get connection with timeout
|
||||||
@ -133,8 +137,7 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
|
|
||||||
conn, err := pool.Take(ctx)
|
conn, err := pool.Take(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error()))
|
return state.PushError("sqlite.query: connection timeout: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
defer pool.Put(conn)
|
defer pool.Put(conn)
|
||||||
|
|
||||||
@ -145,8 +148,7 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
// Set up parameters if provided
|
// Set up parameters if provided
|
||||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
return state.PushError("sqlite.query: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -182,19 +184,12 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
|
|
||||||
// Execute query
|
// Execute query
|
||||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
return state.PushError("sqlite.query: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create result table
|
// Create result using specific map type and PushValue
|
||||||
state.NewTable()
|
if err := state.PushValue(rows); err != nil {
|
||||||
for i, row := range rows {
|
return state.PushError("sqlite.query: %v", err)
|
||||||
state.PushNumber(float64(i + 1))
|
|
||||||
if err := state.PushTable(row); err != nil {
|
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
state.SetTable(-3)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
@ -202,20 +197,24 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
|
|
||||||
// sqlExec executes a SQL statement without returning results
|
// sqlExec executes a SQL statement without returning results
|
||||||
func sqlExec(state *luajit.State) int {
|
func sqlExec(state *luajit.State) int {
|
||||||
// Get required parameters
|
if err := state.CheckMinArgs(2); err != nil {
|
||||||
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
state.PushString("sqlite.exec: requires database name and query")
|
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dbName := state.ToString(1)
|
dbName, err := state.SafeToString(1)
|
||||||
query := state.ToString(2)
|
if err != nil {
|
||||||
|
return state.PushError("sqlite.exec: database name must be string")
|
||||||
|
}
|
||||||
|
|
||||||
|
query, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("sqlite.exec: query must be string")
|
||||||
|
}
|
||||||
|
|
||||||
// Get pool
|
// Get pool
|
||||||
pool, err := getPool(dbName)
|
pool, err := getPool(dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection with timeout
|
// Get connection with timeout
|
||||||
@ -224,8 +223,7 @@ func sqlExec(state *luajit.State) int {
|
|||||||
|
|
||||||
conn, err := pool.Take(ctx)
|
conn, err := pool.Take(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error()))
|
return state.PushError("sqlite.exec: connection timeout: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
defer pool.Put(conn)
|
defer pool.Put(conn)
|
||||||
|
|
||||||
@ -235,8 +233,7 @@ func sqlExec(state *luajit.State) int {
|
|||||||
// Fast path for multi-statement scripts
|
// Fast path for multi-statement scripts
|
||||||
if strings.Contains(query, ";") && !hasParams {
|
if strings.Contains(query, ";") && !hasParams {
|
||||||
if err := sqlitex.ExecScript(conn, query); err != nil {
|
if err := sqlitex.ExecScript(conn, query); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
state.PushNumber(float64(conn.Changes()))
|
state.PushNumber(float64(conn.Changes()))
|
||||||
return 1
|
return 1
|
||||||
@ -245,8 +242,7 @@ func sqlExec(state *luajit.State) int {
|
|||||||
// Fast path for simple queries with no parameters
|
// Fast path for simple queries with no parameters
|
||||||
if !hasParams {
|
if !hasParams {
|
||||||
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
state.PushNumber(float64(conn.Changes()))
|
state.PushNumber(float64(conn.Changes()))
|
||||||
return 1
|
return 1
|
||||||
@ -255,14 +251,12 @@ func sqlExec(state *luajit.State) int {
|
|||||||
// Create execution options for parameterized query
|
// Create execution options for parameterized query
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute with parameters
|
// Execute with parameters
|
||||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return affected rows
|
// Return affected rows
|
||||||
@ -273,11 +267,17 @@ func sqlExec(state *luajit.State) int {
|
|||||||
// setupParams configures execution options with parameters from Lua
|
// setupParams configures execution options with parameters from Lua
|
||||||
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
|
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
|
||||||
if state.IsTable(paramIndex) {
|
if state.IsTable(paramIndex) {
|
||||||
params, err := state.ToTable(paramIndex)
|
paramsAny, err := state.SafeToTable(paramIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid parameters: %w", err)
|
return fmt.Errorf("invalid parameters: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Type assert to map[string]any
|
||||||
|
params, ok := paramsAny.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("parameters must be a table")
|
||||||
|
}
|
||||||
|
|
||||||
// Check for array-style params
|
// Check for array-style params
|
||||||
if arr, ok := params[""]; ok {
|
if arr, ok := params[""]; ok {
|
||||||
if arrParams, ok := arr.([]any); ok {
|
if arrParams, ok := arr.([]any); ok {
|
||||||
@ -321,20 +321,24 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
|
|||||||
|
|
||||||
// sqlGetOne executes a query and returns only the first row
|
// sqlGetOne executes a query and returns only the first row
|
||||||
func sqlGetOne(state *luajit.State) int {
|
func sqlGetOne(state *luajit.State) int {
|
||||||
// Get required parameters
|
if err := state.CheckMinArgs(2); err != nil {
|
||||||
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
state.PushString("sqlite.get_one: requires database name and query")
|
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dbName := state.ToString(1)
|
dbName, err := state.SafeToString(1)
|
||||||
query := state.ToString(2)
|
if err != nil {
|
||||||
|
return state.PushError("sqlite.get_one: database name must be string")
|
||||||
|
}
|
||||||
|
|
||||||
|
query, err := state.SafeToString(2)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("sqlite.get_one: query must be string")
|
||||||
|
}
|
||||||
|
|
||||||
// Get pool
|
// Get pool
|
||||||
pool, err := getPool(dbName)
|
pool, err := getPool(dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection with timeout
|
// Get connection with timeout
|
||||||
@ -343,8 +347,7 @@ func sqlGetOne(state *luajit.State) int {
|
|||||||
|
|
||||||
conn, err := pool.Take(ctx)
|
conn, err := pool.Take(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error()))
|
return state.PushError("sqlite.get_one: connection timeout: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
defer pool.Put(conn)
|
defer pool.Put(conn)
|
||||||
|
|
||||||
@ -355,8 +358,7 @@ func sqlGetOne(state *luajit.State) int {
|
|||||||
// Set up parameters if provided
|
// Set up parameters if provided
|
||||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -395,17 +397,15 @@ func sqlGetOne(state *luajit.State) int {
|
|||||||
|
|
||||||
// Execute query
|
// Execute query
|
||||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return result or nil if no rows
|
// Return result or nil if no rows
|
||||||
if result == nil {
|
if result == nil {
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
} else {
|
} else {
|
||||||
if err := state.PushTable(result); err != nil {
|
if err := state.PushValue(result); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,12 +35,17 @@ func RegisterUtilFunctions(state *luajit.State) error {
|
|||||||
|
|
||||||
// htmlSpecialChars converts special characters to HTML entities
|
// htmlSpecialChars converts special characters to HTML entities
|
||||||
func htmlSpecialChars(state *luajit.State) int {
|
func htmlSpecialChars(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
|
state.PushNil()
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
input, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
input := state.ToString(1)
|
|
||||||
result := html.EscapeString(input)
|
result := html.EscapeString(input)
|
||||||
state.PushString(result)
|
state.PushString(result)
|
||||||
return 1
|
return 1
|
||||||
@ -48,12 +53,17 @@ func htmlSpecialChars(state *luajit.State) int {
|
|||||||
|
|
||||||
// htmlEntities is a more comprehensive version of htmlSpecialChars
|
// htmlEntities is a more comprehensive version of htmlSpecialChars
|
||||||
func htmlEntities(state *luajit.State) int {
|
func htmlEntities(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
|
state.PushNil()
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
input, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
input := state.ToString(1)
|
|
||||||
// First use HTML escape for standard entities
|
// First use HTML escape for standard entities
|
||||||
result := html.EscapeString(input)
|
result := html.EscapeString(input)
|
||||||
|
|
||||||
@ -86,12 +96,17 @@ func htmlEntities(state *luajit.State) int {
|
|||||||
|
|
||||||
// base64Encode encodes a string to base64
|
// base64Encode encodes a string to base64
|
||||||
func base64Encode(state *luajit.State) int {
|
func base64Encode(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
|
state.PushNil()
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
input, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
input := state.ToString(1)
|
|
||||||
result := base64.StdEncoding.EncodeToString([]byte(input))
|
result := base64.StdEncoding.EncodeToString([]byte(input))
|
||||||
state.PushString(result)
|
state.PushString(result)
|
||||||
return 1
|
return 1
|
||||||
@ -99,12 +114,17 @@ func base64Encode(state *luajit.State) int {
|
|||||||
|
|
||||||
// base64Decode decodes a base64 string
|
// base64Decode decodes a base64 string
|
||||||
func base64Decode(state *luajit.State) int {
|
func base64Decode(state *luajit.State) int {
|
||||||
if !state.IsString(1) {
|
if err := state.CheckExactArgs(1); err != nil {
|
||||||
|
state.PushNil()
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
input, err := state.SafeToString(1)
|
||||||
|
if err != nil {
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
input := state.ToString(1)
|
|
||||||
result, err := base64.StdEncoding.DecodeString(input)
|
result, err := base64.StdEncoding.DecodeString(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user