migrate to new LJTG API

This commit is contained in:
Sky Johnson 2025-06-02 11:15:56 -05:00
parent cc6a7675d8
commit e2b1b932ff
10 changed files with 514 additions and 558 deletions

View File

@ -78,16 +78,21 @@ func CleanupCrypto(state *luajit.State) {
// cryptoHash generates hash digests using various algorithms // cryptoHash generates hash digests using various algorithms
func cryptoHash(state *luajit.State) int { func cryptoHash(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) { if err := state.CheckMinArgs(2); err != nil {
state.PushString("hash: expected (string data, string algorithm)") return state.PushError("hash: %v", err)
return 1
} }
data := state.ToString(1) data, err := state.SafeToString(1)
algorithm := state.ToString(2) if err != nil {
return state.PushError("hash: data must be string")
}
algorithm, err := state.SafeToString(2)
if err != nil {
return state.PushError("hash: algorithm must be string")
}
var h hash.Hash var h hash.Hash
switch algorithm { switch algorithm {
case "md5": case "md5":
h = md5.New() h = md5.New()
@ -98,8 +103,7 @@ func cryptoHash(state *luajit.State) int {
case "sha512": case "sha512":
h = sha512.New() h = sha512.New()
default: default:
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm)) return state.PushError("unsupported algorithm: %s", algorithm)
return 1
} }
h.Write([]byte(data)) h.Write([]byte(data))
@ -107,8 +111,10 @@ func cryptoHash(state *luajit.State) int {
// Output format // Output format
outputFormat := "hex" outputFormat := "hex"
if state.GetTop() >= 3 && state.IsString(3) { if state.GetTop() >= 3 {
outputFormat = state.ToString(3) if format, err := state.SafeToString(3); err == nil {
outputFormat = format
}
} }
switch outputFormat { switch outputFormat {
@ -125,17 +131,26 @@ func cryptoHash(state *luajit.State) int {
// cryptoHmac generates HMAC using various hash algorithms // cryptoHmac generates HMAC using various hash algorithms
func cryptoHmac(state *luajit.State) int { func cryptoHmac(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) || !state.IsString(3) { if err := state.CheckMinArgs(3); err != nil {
state.PushString("hmac: expected (string data, string key, string algorithm)") return state.PushError("hmac: %v", err)
return 1
} }
data := state.ToString(1) data, err := state.SafeToString(1)
key := state.ToString(2) if err != nil {
algorithm := state.ToString(3) return state.PushError("hmac: data must be string")
}
key, err := state.SafeToString(2)
if err != nil {
return state.PushError("hmac: key must be string")
}
algorithm, err := state.SafeToString(3)
if err != nil {
return state.PushError("hmac: algorithm must be string")
}
var h func() hash.Hash var h func() hash.Hash
switch algorithm { switch algorithm {
case "md5": case "md5":
h = md5.New h = md5.New
@ -146,8 +161,7 @@ func cryptoHmac(state *luajit.State) int {
case "sha512": case "sha512":
h = sha512.New h = sha512.New
default: default:
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm)) return state.PushError("unsupported algorithm: %s", algorithm)
return 1
} }
mac := hmac.New(h, []byte(key)) mac := hmac.New(h, []byte(key))
@ -156,8 +170,10 @@ func cryptoHmac(state *luajit.State) int {
// Output format // Output format
outputFormat := "hex" outputFormat := "hex"
if state.GetTop() >= 4 && state.IsString(4) { if state.GetTop() >= 4 {
outputFormat = state.ToString(4) if format, err := state.SafeToString(4); err == nil {
outputFormat = format
}
} }
switch outputFormat { switch outputFormat {
@ -175,10 +191,8 @@ func cryptoHmac(state *luajit.State) int {
// cryptoUuid generates a random UUID v4 // cryptoUuid generates a random UUID v4
func cryptoUuid(state *luajit.State) int { func cryptoUuid(state *luajit.State) int {
uuid := make([]byte, 16) uuid := make([]byte, 16)
_, err := rand.Read(uuid) if _, err := rand.Read(uuid); err != nil {
if err != nil { return state.PushError("uuid: generation error: %v", err)
state.PushString(fmt.Sprintf("uuid: generation error: %v", err))
return 1
} }
// Set version (4) and variant (RFC 4122) // Set version (4) and variant (RFC 4122)
@ -194,15 +208,17 @@ func cryptoUuid(state *luajit.State) int {
// cryptoRandomBytes generates random bytes // cryptoRandomBytes generates random bytes
func cryptoRandomBytes(state *luajit.State) int { func cryptoRandomBytes(state *luajit.State) int {
if !state.IsNumber(1) { if err := state.CheckMinArgs(1); err != nil {
state.PushString("random_bytes: expected (number length)") return state.PushError("random_bytes: %v", err)
return 1 }
length, err := state.SafeToNumber(1)
if err != nil {
return state.PushError("random_bytes: length must be number")
} }
length := int(state.ToNumber(1))
if length <= 0 { if length <= 0 {
state.PushString("random_bytes: length must be positive") return state.PushError("random_bytes: length must be positive")
return 1
} }
// Check if secure // Check if secure
@ -211,13 +227,11 @@ func cryptoRandomBytes(state *luajit.State) int {
secure = state.ToBoolean(2) secure = state.ToBoolean(2)
} }
bytes := make([]byte, length) bytes := make([]byte, int(length))
if secure { if secure {
_, err := rand.Read(bytes) if _, err := rand.Read(bytes); err != nil {
if err != nil { return state.PushError("random_bytes: error: %v", err)
state.PushString(fmt.Sprintf("random_bytes: error: %v", err))
return 1
} }
} else { } else {
stateRngsMu.Lock() stateRngsMu.Lock()
@ -225,8 +239,7 @@ func cryptoRandomBytes(state *luajit.State) int {
stateRngsMu.Unlock() stateRngsMu.Unlock()
if !ok { if !ok {
state.PushString("random_bytes: RNG not initialized") return state.PushError("random_bytes: RNG not initialized")
return 1
} }
for i := range bytes { for i := range bytes {
@ -236,8 +249,10 @@ func cryptoRandomBytes(state *luajit.State) int {
// Output format // Output format
outputFormat := "binary" outputFormat := "binary"
if state.GetTop() >= 3 && state.IsString(3) { if state.GetTop() >= 3 {
outputFormat = state.ToString(3) if format, err := state.SafeToString(3); err == nil {
outputFormat = format
}
} }
switch outputFormat { switch outputFormat {
@ -254,17 +269,25 @@ func cryptoRandomBytes(state *luajit.State) int {
// cryptoRandomInt generates a random integer in range [min, max] // cryptoRandomInt generates a random integer in range [min, max]
func cryptoRandomInt(state *luajit.State) int { func cryptoRandomInt(state *luajit.State) int {
if !state.IsNumber(1) || !state.IsNumber(2) { if err := state.CheckMinArgs(2); err != nil {
state.PushString("random_int: expected (number min, number max)") return state.PushError("random_int: %v", err)
return 1
} }
min := int64(state.ToNumber(1)) minVal, err := state.SafeToNumber(1)
max := int64(state.ToNumber(2)) if err != nil {
return state.PushError("random_int: min must be number")
}
maxVal, err := state.SafeToNumber(2)
if err != nil {
return state.PushError("random_int: max must be number")
}
min := int64(minVal)
max := int64(maxVal)
if max <= min { if max <= min {
state.PushString("random_int: max must be greater than min") return state.PushError("random_int: max must be greater than min")
return 1
} }
// Check if secure // Check if secure
@ -274,15 +297,12 @@ func cryptoRandomInt(state *luajit.State) int {
} }
range_size := max - min + 1 range_size := max - min + 1
var result int64 var result int64
if secure { if secure {
bytes := make([]byte, 8) bytes := make([]byte, 8)
_, err := rand.Read(bytes) if _, err := rand.Read(bytes); err != nil {
if err != nil { return state.PushError("random_int: error: %v", err)
state.PushString(fmt.Sprintf("random_int: error: %v", err))
return 1
} }
val := binary.BigEndian.Uint64(bytes) val := binary.BigEndian.Uint64(bytes)
@ -293,8 +313,7 @@ func cryptoRandomInt(state *luajit.State) int {
stateRngsMu.Unlock() stateRngsMu.Unlock()
if !ok { if !ok {
state.PushString("random_int: RNG not initialized") return state.PushError("random_int: RNG not initialized")
return 1
} }
result = min + int64(stateRng.Uint64()%uint64(range_size)) result = min + int64(stateRng.Uint64()%uint64(range_size))
@ -315,10 +334,8 @@ func cryptoRandom(state *luajit.State) int {
if numArgs == 0 { if numArgs == 0 {
if secure { if secure {
bytes := make([]byte, 8) bytes := make([]byte, 8)
_, err := rand.Read(bytes) if _, err := rand.Read(bytes); err != nil {
if err != nil { return state.PushError("random: error: %v", err)
state.PushString(fmt.Sprintf("random: error: %v", err))
return 1
} }
val := binary.BigEndian.Uint64(bytes) val := binary.BigEndian.Uint64(bytes)
state.PushNumber(float64(val) / float64(math.MaxUint64)) state.PushNumber(float64(val) / float64(math.MaxUint64))
@ -328,8 +345,7 @@ func cryptoRandom(state *luajit.State) int {
stateRngsMu.Unlock() stateRngsMu.Unlock()
if !ok { if !ok {
state.PushString("random: RNG not initialized") return state.PushError("random: RNG not initialized")
return 1
} }
state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64)) state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64))
@ -341,8 +357,7 @@ func cryptoRandom(state *luajit.State) int {
if numArgs == 1 && state.IsNumber(1) { if numArgs == 1 && state.IsNumber(1) {
n := int64(state.ToNumber(1)) n := int64(state.ToNumber(1))
if n < 1 { if n < 1 {
state.PushString("random: upper bound must be >= 1") return state.PushError("random: upper bound must be >= 1")
return 1
} }
state.PushNumber(1) // min state.PushNumber(1) // min
@ -357,21 +372,23 @@ func cryptoRandom(state *luajit.State) int {
return cryptoRandomInt(state) return cryptoRandomInt(state)
} }
state.PushString("random: invalid arguments") return state.PushError("random: invalid arguments")
return 1
} }
// cryptoRandomSeed sets seed for non-secure RNG // cryptoRandomSeed sets seed for non-secure RNG
func cryptoRandomSeed(state *luajit.State) int { func cryptoRandomSeed(state *luajit.State) int {
if !state.IsNumber(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("randomseed: expected (number seed)") return state.PushError("randomseed: %v", err)
return 1
} }
seed := uint64(state.ToNumber(1)) seed, err := state.SafeToNumber(1)
if err != nil {
return state.PushError("randomseed: seed must be number")
}
seedVal := uint64(seed)
stateRngsMu.Lock() stateRngsMu.Lock()
stateRngs[state] = mrand.NewPCG(seed, seed>>32) stateRngs[state] = mrand.NewPCG(seedVal, seedVal>>32)
stateRngsMu.Unlock() stateRngsMu.Unlock()
return 0 return 0

View File

@ -71,7 +71,7 @@ var (
// precompileModule compiles a module's code to bytecode once // precompileModule compiles a module's code to bytecode once
func precompileModule(m *ModuleInfo) { func precompileModule(m *ModuleInfo) {
m.Once.Do(func() { m.Once.Do(func() {
tempState := luajit.New() tempState := luajit.New(true) // Explicitly open standard libraries
if tempState == nil { if tempState == nil {
logger.Fatalf("Failed to create temp Lua state for %s module compilation", m.Name) logger.Fatalf("Failed to create temp Lua state for %s module compilation", m.Name)
return return
@ -105,18 +105,14 @@ func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error {
logger.Debugf("Loading %s.lua from precompiled bytecode", m.Name) logger.Debugf("Loading %s.lua from precompiled bytecode", m.Name)
} }
if err := state.LoadBytecode(*bytecode, m.Name+".lua"); err != nil {
return err
}
if m.DefinesGlobal { if m.DefinesGlobal {
// Module defines its own globals, just run it // Module defines its own globals, just run it
if err := state.RunBytecode(); err != nil { if err := state.LoadAndRunBytecode(*bytecode, m.Name+".lua"); err != nil {
return err return err
} }
} else { } else {
// Module returns a table, capture and set as global // Module returns a table, capture and set as global
if err := state.RunBytecodeWithResults(1); err != nil { if err := state.LoadAndRunBytecodeWithResults(*bytecode, m.Name+".lua", 1); err != nil {
return err return err
} }
state.SetGlobal(m.Name) state.SetGlobal(m.Name)

View File

@ -221,12 +221,17 @@ func CleanupEnv() error {
// envGet Lua function to get an environment variable // envGet Lua function to get an environment variable
func envGet(state *luajit.State) int { func envGet(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
key, err := state.SafeToString(1)
if err != nil {
state.PushNil() state.PushNil()
return 1 return 1
} }
key := state.ToString(1)
if value, exists := globalEnvManager.Get(key); exists { if value, exists := globalEnvManager.Get(key); exists {
if err := state.PushValue(value); err != nil { if err := state.PushValue(value); err != nil {
state.PushNil() state.PushNil()
@ -239,13 +244,22 @@ func envGet(state *luajit.State) int {
// envSet Lua function to set an environment variable // envSet Lua function to set an environment variable
func envSet(state *luajit.State) int { func envSet(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) { if err := state.CheckExactArgs(2); err != nil {
state.PushBoolean(false) state.PushBoolean(false)
return 1 return 1
} }
key := state.ToString(1) key, err := state.SafeToString(1)
value := state.ToString(2) if err != nil {
state.PushBoolean(false)
return 1
}
value, err := state.SafeToString(2)
if err != nil {
state.PushBoolean(false)
return 1
}
globalEnvManager.Set(key, value) globalEnvManager.Set(key, value)
state.PushBoolean(true) state.PushBoolean(true)
@ -255,11 +269,9 @@ func envSet(state *luajit.State) int {
// envGetAll Lua function to get all environment variables // envGetAll Lua function to get all environment variables
func envGetAll(state *luajit.State) int { func envGetAll(state *luajit.State) int {
vars := globalEnvManager.GetAll() vars := globalEnvManager.GetAll()
if err := state.PushValue(vars); err != nil {
if err := state.PushTable(vars); err != nil {
state.PushNil() state.PushNil()
} }
return 1 return 1
} }

View File

@ -112,23 +112,24 @@ func getCacheKey(fullPath string, modTime time.Time) string {
// fsReadFile reads a file and returns its contents // fsReadFile reads a file and returns its contents
func fsReadFile(state *luajit.State) int { func fsReadFile(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("fs.read_file: path must be a string") return state.PushError("fs.read_file: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.read_file: path must be string")
} }
path := state.ToString(1)
fullPath, err := ResolvePath(path) fullPath, err := ResolvePath(path)
if err != nil { if err != nil {
state.PushString("fs.read_file: " + err.Error()) return state.PushError("fs.read_file: %v", err)
return -1
} }
// Get file info for cache key and validation // Get file info for cache key and validation
info, err := os.Stat(fullPath) info, err := os.Stat(fullPath)
if err != nil { if err != nil {
state.PushString("fs.read_file: " + err.Error()) return state.PushError("fs.read_file: %v", err)
return -1
} }
// Create cache key with path and modification time // Create cache key with path and modification time
@ -154,8 +155,7 @@ func fsReadFile(state *luajit.State) int {
stats.misses++ stats.misses++
data, err := os.ReadFile(fullPath) data, err := os.ReadFile(fullPath)
if err != nil { if err != nil {
state.PushString("fs.read_file: " + err.Error()) return state.PushError("fs.read_file: %v", err)
return -1
} }
// Compress and cache the data // Compress and cache the data
@ -170,41 +170,33 @@ func fsReadFile(state *luajit.State) int {
// fsWriteFile writes data to a file // fsWriteFile writes data to a file
func fsWriteFile(state *luajit.State) int { func fsWriteFile(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(2); err != nil {
state.PushString("fs.write_file: path must be a string") return state.PushError("fs.write_file: %v", err)
return -1
} }
path := state.ToString(1)
if !state.IsString(2) { path, err := state.SafeToString(1)
state.PushString("fs.write_file: content must be a string") if err != nil {
return -1 return state.PushError("fs.write_file: path must be string")
}
content, err := state.SafeToString(2)
if err != nil {
return state.PushError("fs.write_file: content must be string")
} }
content := state.ToString(2)
fullPath, err := ResolvePath(path) fullPath, err := ResolvePath(path)
if err != nil { if err != nil {
state.PushString("fs.write_file: " + err.Error()) return state.PushError("fs.write_file: %v", err)
return -1
} }
// Ensure the directory exists // Ensure the directory exists
dir := filepath.Dir(fullPath) dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
state.PushString("fs.write_file: failed to create directory: " + err.Error()) return state.PushError("fs.write_file: failed to create directory: %v", err)
return -1
} }
err = os.WriteFile(fullPath, []byte(content), 0644) if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil {
if err != nil { return state.PushError("fs.write_file: %v", err)
state.PushString("fs.write_file: " + err.Error())
return -1
}
// Invalidate cache entries for this file path
if fileCache != nil {
// We can't easily iterate through cache keys, so we'll let the cache
// naturally expire old entries when the file is read again
} }
state.PushBoolean(true) state.PushBoolean(true)
@ -213,42 +205,39 @@ func fsWriteFile(state *luajit.State) int {
// fsAppendFile appends data to a file // fsAppendFile appends data to a file
func fsAppendFile(state *luajit.State) int { func fsAppendFile(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(2); err != nil {
state.PushString("fs.append_file: path must be a string") return state.PushError("fs.append_file: %v", err)
return -1
} }
path := state.ToString(1)
if !state.IsString(2) { path, err := state.SafeToString(1)
state.PushString("fs.append_file: content must be a string") if err != nil {
return -1 return state.PushError("fs.append_file: path must be string")
}
content, err := state.SafeToString(2)
if err != nil {
return state.PushError("fs.append_file: content must be string")
} }
content := state.ToString(2)
fullPath, err := ResolvePath(path) fullPath, err := ResolvePath(path)
if err != nil { if err != nil {
state.PushString("fs.append_file: " + err.Error()) return state.PushError("fs.append_file: %v", err)
return -1
} }
// Ensure the directory exists // Ensure the directory exists
dir := filepath.Dir(fullPath) dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
state.PushString("fs.append_file: failed to create directory: " + err.Error()) return state.PushError("fs.append_file: failed to create directory: %v", err)
return -1
} }
file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
state.PushString("fs.append_file: " + err.Error()) return state.PushError("fs.append_file: %v", err)
return -1
} }
defer file.Close() defer file.Close()
_, err = file.Write([]byte(content)) if _, err = file.Write([]byte(content)); err != nil {
if err != nil { return state.PushError("fs.append_file: %v", err)
state.PushString("fs.append_file: " + err.Error())
return -1
} }
state.PushBoolean(true) state.PushBoolean(true)
@ -257,16 +246,18 @@ func fsAppendFile(state *luajit.State) int {
// fsExists checks if a file or directory exists // fsExists checks if a file or directory exists
func fsExists(state *luajit.State) int { func fsExists(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("fs.exists: path must be a string") return state.PushError("fs.exists: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.exists: path must be string")
} }
path := state.ToString(1)
fullPath, err := ResolvePath(path) fullPath, err := ResolvePath(path)
if err != nil { if err != nil {
state.PushString("fs.exists: " + err.Error()) return state.PushError("fs.exists: %v", err)
return -1
} }
_, err = os.Stat(fullPath) _, err = os.Stat(fullPath)
@ -276,34 +267,32 @@ func fsExists(state *luajit.State) int {
// fsRemoveFile removes a file // fsRemoveFile removes a file
func fsRemoveFile(state *luajit.State) int { func fsRemoveFile(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("fs.remove_file: path must be a string") return state.PushError("fs.remove_file: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.remove_file: path must be string")
} }
path := state.ToString(1)
fullPath, err := ResolvePath(path) fullPath, err := ResolvePath(path)
if err != nil { if err != nil {
state.PushString("fs.remove_file: " + err.Error()) return state.PushError("fs.remove_file: %v", err)
return -1
} }
// Check if it's a directory // Check if it's a directory
info, err := os.Stat(fullPath) info, err := os.Stat(fullPath)
if err != nil { if err != nil {
state.PushString("fs.remove_file: " + err.Error()) return state.PushError("fs.remove_file: %v", err)
return -1
} }
if info.IsDir() { if info.IsDir() {
state.PushString("fs.remove_file: cannot remove directory, use remove_dir instead") return state.PushError("fs.remove_file: cannot remove directory, use remove_dir instead")
return -1
} }
err = os.Remove(fullPath) if err := os.Remove(fullPath); err != nil {
if err != nil { return state.PushError("fs.remove_file: %v", err)
state.PushString("fs.remove_file: " + err.Error())
return -1
} }
state.PushBoolean(true) state.PushBoolean(true)
@ -312,67 +301,65 @@ func fsRemoveFile(state *luajit.State) int {
// fsGetInfo gets information about a file // fsGetInfo gets information about a file
func fsGetInfo(state *luajit.State) int { func fsGetInfo(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("fs.get_info: path must be a string") return state.PushError("fs.get_info: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.get_info: path must be string")
} }
path := state.ToString(1)
fullPath, err := ResolvePath(path) fullPath, err := ResolvePath(path)
if err != nil { if err != nil {
state.PushString("fs.get_info: " + err.Error()) return state.PushError("fs.get_info: %v", err)
return -1
} }
info, err := os.Stat(fullPath) info, err := os.Stat(fullPath)
if err != nil { if err != nil {
state.PushString("fs.get_info: " + err.Error()) return state.PushError("fs.get_info: %v", err)
return -1
} }
state.NewTable() fileInfo := map[string]any{
"name": info.Name(),
"size": info.Size(),
"mode": int(info.Mode()),
"mod_time": info.ModTime().Unix(),
"is_dir": info.IsDir(),
}
state.PushString(info.Name()) if err := state.PushValue(fileInfo); err != nil {
state.SetField(-2, "name") return state.PushError("fs.get_info: %v", err)
}
state.PushNumber(float64(info.Size()))
state.SetField(-2, "size")
state.PushNumber(float64(info.Mode()))
state.SetField(-2, "mode")
state.PushNumber(float64(info.ModTime().Unix()))
state.SetField(-2, "mod_time")
state.PushBoolean(info.IsDir())
state.SetField(-2, "is_dir")
return 1 return 1
} }
// fsMakeDir creates a directory // fsMakeDir creates a directory
func fsMakeDir(state *luajit.State) int { func fsMakeDir(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckMinArgs(1); err != nil {
state.PushString("fs.make_dir: path must be a string") return state.PushError("fs.make_dir: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.make_dir: path must be string")
} }
path := state.ToString(1)
perm := os.FileMode(0755) perm := os.FileMode(0755)
if state.GetTop() >= 2 && state.IsNumber(2) { if state.GetTop() >= 2 {
perm = os.FileMode(state.ToNumber(2)) if permVal, err := state.SafeToNumber(2); err == nil {
perm = os.FileMode(permVal)
}
} }
fullPath, err := ResolvePath(path) fullPath, err := ResolvePath(path)
if err != nil { if err != nil {
state.PushString("fs.make_dir: " + err.Error()) return state.PushError("fs.make_dir: %v", err)
return -1
} }
err = os.MkdirAll(fullPath, perm) if err := os.MkdirAll(fullPath, perm); err != nil {
if err != nil { return state.PushError("fs.make_dir: %v", err)
state.PushString("fs.make_dir: " + err.Error())
return -1
} }
state.PushBoolean(true) state.PushBoolean(true)
@ -381,41 +368,42 @@ func fsMakeDir(state *luajit.State) int {
// fsListDir lists the contents of a directory // fsListDir lists the contents of a directory
func fsListDir(state *luajit.State) int { func fsListDir(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("fs.list_dir: path must be a string") return state.PushError("fs.list_dir: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.list_dir: path must be string")
} }
path := state.ToString(1)
fullPath, err := ResolvePath(path) fullPath, err := ResolvePath(path)
if err != nil { if err != nil {
state.PushString("fs.list_dir: " + err.Error()) return state.PushError("fs.list_dir: %v", err)
return -1
} }
info, err := os.Stat(fullPath) info, err := os.Stat(fullPath)
if err != nil { if err != nil {
state.PushString("fs.list_dir: " + err.Error()) return state.PushError("fs.list_dir: %v", err)
return -1
} }
if !info.IsDir() { if !info.IsDir() {
state.PushString("fs.list_dir: not a directory") return state.PushError("fs.list_dir: not a directory")
return -1
} }
files, err := os.ReadDir(fullPath) files, err := os.ReadDir(fullPath)
if err != nil { if err != nil {
state.PushString("fs.list_dir: " + err.Error()) return state.PushError("fs.list_dir: %v", err)
return -1
} }
state.NewTable() // Create array of filenames
filenames := make([]string, len(files))
for i, file := range files { for i, file := range files {
state.PushNumber(float64(i + 1)) filenames[i] = file.Name()
state.PushString(file.Name()) }
state.SetTable(-3)
if err := state.PushValue(filenames); err != nil {
return state.PushError("fs.list_dir: %v", err)
} }
return 1 return 1
@ -423,11 +411,14 @@ func fsListDir(state *luajit.State) int {
// fsRemoveDir removes a directory // fsRemoveDir removes a directory
func fsRemoveDir(state *luajit.State) int { func fsRemoveDir(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckMinArgs(1); err != nil {
state.PushString("fs.remove_dir: path must be a string") return state.PushError("fs.remove_dir: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.remove_dir: path must be string")
} }
path := state.ToString(1)
recursive := false recursive := false
if state.GetTop() >= 2 { if state.GetTop() >= 2 {
@ -436,19 +427,16 @@ func fsRemoveDir(state *luajit.State) int {
fullPath, err := ResolvePath(path) fullPath, err := ResolvePath(path)
if err != nil { if err != nil {
state.PushString("fs.remove_dir: " + err.Error()) return state.PushError("fs.remove_dir: %v", err)
return -1
} }
info, err := os.Stat(fullPath) info, err := os.Stat(fullPath)
if err != nil { if err != nil {
state.PushString("fs.remove_dir: " + err.Error()) return state.PushError("fs.remove_dir: %v", err)
return -1
} }
if !info.IsDir() { if !info.IsDir() {
state.PushString("fs.remove_dir: not a directory") return state.PushError("fs.remove_dir: not a directory")
return -1
} }
if recursive { if recursive {
@ -458,8 +446,7 @@ func fsRemoveDir(state *luajit.State) int {
} }
if err != nil { if err != nil {
state.PushString("fs.remove_dir: " + err.Error()) return state.PushError("fs.remove_dir: %v", err)
return -1
} }
state.PushBoolean(true) state.PushBoolean(true)
@ -468,19 +455,17 @@ func fsRemoveDir(state *luajit.State) int {
// fsJoinPaths joins path components // fsJoinPaths joins path components
func fsJoinPaths(state *luajit.State) int { func fsJoinPaths(state *luajit.State) int {
nargs := state.GetTop() if err := state.CheckMinArgs(1); err != nil {
if nargs < 1 { return state.PushError("fs.join_paths: %v", err)
state.PushString("fs.join_paths: at least one path component required")
return -1
} }
components := make([]string, nargs) components := make([]string, state.GetTop())
for i := 1; i <= nargs; i++ { for i := 1; i <= state.GetTop(); i++ {
if !state.IsString(i) { comp, err := state.SafeToString(i)
state.PushString("fs.join_paths: all arguments must be strings") if err != nil {
return -1 return state.PushError("fs.join_paths: all arguments must be strings")
} }
components[i-1] = state.ToString(i) components[i-1] = comp
} }
result := filepath.Join(components...) result := filepath.Join(components...)
@ -492,11 +477,14 @@ func fsJoinPaths(state *luajit.State) int {
// fsDirName returns the directory portion of a path // fsDirName returns the directory portion of a path
func fsDirName(state *luajit.State) int { func fsDirName(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("fs.dir_name: path must be a string") return state.PushError("fs.dir_name: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.dir_name: path must be string")
} }
path := state.ToString(1)
dir := filepath.Dir(path) dir := filepath.Dir(path)
dir = strings.ReplaceAll(dir, "\\", "/") dir = strings.ReplaceAll(dir, "\\", "/")
@ -507,28 +495,32 @@ func fsDirName(state *luajit.State) int {
// fsBaseName returns the file name portion of a path // fsBaseName returns the file name portion of a path
func fsBaseName(state *luajit.State) int { func fsBaseName(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("fs.base_name: path must be a string") return state.PushError("fs.base_name: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.base_name: path must be string")
} }
path := state.ToString(1)
base := filepath.Base(path) base := filepath.Base(path)
state.PushString(base) state.PushString(base)
return 1 return 1
} }
// fsExtension returns the file extension // fsExtension returns the file extension
func fsExtension(state *luajit.State) int { func fsExtension(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("fs.extension: path must be a string") return state.PushError("fs.extension: %v", err)
return -1 }
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.extension: path must be string")
} }
path := state.ToString(1)
ext := filepath.Ext(path) ext := filepath.Ext(path)
state.PushString(ext) state.PushString(ext)
return 1 return 1
} }

View File

@ -94,25 +94,26 @@ func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
// httpRequest makes an HTTP request and returns the result to Lua // httpRequest makes an HTTP request and returns the result to Lua
func httpRequest(state *luajit.State) int { func httpRequest(state *luajit.State) int {
// Get method (required) if err := state.CheckMinArgs(2); err != nil {
if !state.IsString(1) { return state.PushError("http.client.request: %v", err)
state.PushString("http.client.request: method must be a string")
return -1
} }
method := strings.ToUpper(state.ToString(1))
// Get URL (required) // Get method and URL
if !state.IsString(2) { method, err := state.SafeToString(1)
state.PushString("http.client.request: url must be a string") if err != nil {
return -1 return state.PushError("http.client.request: method must be string")
}
method = strings.ToUpper(method)
urlStr, err := state.SafeToString(2)
if err != nil {
return state.PushError("http.client.request: url must be string")
} }
urlStr := state.ToString(2)
// Parse URL to check if it's valid // Parse URL to check if it's valid
parsedURL, err := url.Parse(urlStr) parsedURL, err := url.Parse(urlStr)
if err != nil { if err != nil {
state.PushString("Invalid URL: " + err.Error()) return state.PushError("Invalid URL: %v", err)
return -1
} }
// Get client configuration // Get client configuration
@ -120,8 +121,7 @@ func httpRequest(state *luajit.State) int {
// Check if remote connections are allowed // Check if remote connections are allowed
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") { if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
state.PushString("Remote connections are not allowed") return state.PushError("Remote connections are not allowed")
return -1
} }
// Use bytebufferpool for request and response // Use bytebufferpool for request and response
@ -139,13 +139,13 @@ func httpRequest(state *luajit.State) int {
if state.GetTop() >= 3 && !state.IsNil(3) { if state.GetTop() >= 3 && !state.IsNil(3) {
if state.IsString(3) { if state.IsString(3) {
// String body // String body
req.SetBodyString(state.ToString(3)) bodyStr, _ := state.SafeToString(3)
req.SetBodyString(bodyStr)
} else if state.IsTable(3) { } else if state.IsTable(3) {
// Table body - convert to JSON // Table body - convert to JSON
luaTable, err := state.ToTable(3) luaTable, err := state.SafeToTable(3)
if err != nil { if err != nil {
state.PushString("Failed to parse body table: " + err.Error()) return state.PushError("Failed to parse body table: %v", err)
return -1
} }
// Use bytebufferpool for JSON serialization // Use bytebufferpool for JSON serialization
@ -153,42 +153,37 @@ func httpRequest(state *luajit.State) int {
defer bytebufferpool.Put(buf) defer bytebufferpool.Put(buf)
if err := json.NewEncoder(buf).Encode(luaTable); err != nil { if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
state.PushString("Failed to convert body to JSON: " + err.Error()) return state.PushError("Failed to convert body to JSON: %v", err)
return -1
} }
req.SetBody(buf.Bytes()) req.SetBody(buf.Bytes())
req.Header.SetContentType("application/json") req.Header.SetContentType("application/json")
} else { } else {
state.PushString("Body must be a string or table") return state.PushError("Body must be a string or table")
return -1
} }
} }
// Process options (headers, timeout, etc.) // Process options (headers, timeout, etc.)
timeout := config.DefaultTimeout timeout := config.DefaultTimeout
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) { if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
// Process headers // Process headers using ForEachTableKV
state.GetField(4, "headers") if headers, ok := state.GetFieldTable(4, "headers"); ok {
if state.IsTable(-1) { if headerMap, ok := headers.(map[string]string); ok {
// Iterate through headers for name, value := range headerMap {
state.PushNil() // Start iteration req.Header.Set(name, value)
for state.Next(-2) { }
// Stack now has key at -2 and value at -1 } else if headerMapAny, ok := headers.(map[string]any); ok {
if state.IsString(-2) && state.IsString(-1) { for name, value := range headerMapAny {
headerName := state.ToString(-2) if valueStr, ok := value.(string); ok {
headerValue := state.ToString(-1) req.Header.Set(name, valueStr)
req.Header.Set(headerName, headerValue) }
} }
state.Pop(1) // Pop value, leave key for next iteration
} }
} }
state.Pop(1) // Pop headers table
// Get timeout // Get timeout
state.GetField(4, "timeout") if timeoutVal := state.GetFieldNumber(4, "timeout", 0); timeoutVal > 0 {
if state.IsNumber(-1) { requestTimeout := time.Duration(timeoutVal) * time.Second
requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second
// Apply max timeout if configured // Apply max timeout if configured
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout { if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
@ -197,38 +192,34 @@ func httpRequest(state *luajit.State) int {
timeout = requestTimeout timeout = requestTimeout
} }
} }
state.Pop(1) // Pop timeout
// Process query parameters // Process query parameters
state.GetField(4, "query") if query, ok := state.GetFieldTable(4, "query"); ok {
if state.IsTable(-1) {
// Create URL args
args := req.URI().QueryArgs() args := req.URI().QueryArgs()
// Iterate through query params if queryMap, ok := query.(map[string]string); ok {
state.PushNil() // Start iteration for name, value := range queryMap {
for state.Next(-2) { args.Add(name, value)
if state.IsString(-2) { }
paramName := state.ToString(-2) } else if queryMapAny, ok := query.(map[string]any); ok {
for name, value := range queryMapAny {
// Handle different value types switch v := value.(type) {
if state.IsString(-1) { case string:
args.Add(paramName, state.ToString(-1)) args.Add(name, v)
} else if state.IsNumber(-1) { case int:
args.Add(paramName, strings.TrimRight(strings.TrimRight( args.Add(name, fmt.Sprintf("%d", v))
state.ToString(-1), "0"), ".")) case float64:
} else if state.IsBoolean(-1) { args.Add(name, strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", v), "0"), "."))
if state.ToBoolean(-1) { case bool:
args.Add(paramName, "true") if v {
args.Add(name, "true")
} else { } else {
args.Add(paramName, "false") args.Add(name, "false")
} }
} }
} }
state.Pop(1) // Pop value, leave key for next iteration
} }
} }
state.Pop(1) // Pop query table
} }
// Create context with timeout // Create context with timeout
@ -242,26 +233,18 @@ func httpRequest(state *luajit.State) int {
if errors.Is(err, fasthttp.ErrTimeout) { if errors.Is(err, fasthttp.ErrTimeout) {
errStr = "Request timed out after " + timeout.String() errStr = "Request timed out after " + timeout.String()
} }
state.PushString(errStr) return state.PushError("%s", errStr)
return -1
} }
// Create response table // Create response using TableBuilder
state.NewTable() builder := state.NewTableBuilder()
// Set status code // Set status code and text
state.PushNumber(float64(resp.StatusCode())) builder.SetNumber("status", float64(resp.StatusCode()))
state.SetField(-2, "status") builder.SetString("status_text", fasthttp.StatusMessage(resp.StatusCode()))
// Set status text
statusText := fasthttp.StatusMessage(resp.StatusCode())
state.PushString(statusText)
state.SetField(-2, "status_text")
// Set body // Set body
var respBody []byte var respBody []byte
// Apply size limits to response
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize { if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
// Make a limited copy // Make a limited copy
respBody = make([]byte, config.MaxResponseSize) respBody = make([]byte, config.MaxResponseSize)
@ -270,32 +253,28 @@ func httpRequest(state *luajit.State) int {
respBody = resp.Body() respBody = resp.Body()
} }
state.PushString(string(respBody)) builder.SetString("body", string(respBody))
state.SetField(-2, "body")
// Parse body as JSON if content type is application/json // Parse body as JSON if content type is application/json
contentType := string(resp.Header.ContentType()) contentType := string(resp.Header.ContentType())
if strings.Contains(contentType, "application/json") { if strings.Contains(contentType, "application/json") {
var jsonData any var jsonData any
if err := json.Unmarshal(respBody, &jsonData); err == nil { if err := json.Unmarshal(respBody, &jsonData); err == nil {
if err := state.PushValue(jsonData); err == nil { builder.SetTable("json", jsonData)
state.SetField(-2, "json")
}
} }
} }
// Set headers // Set headers
state.NewTable() headers := make(map[string]string)
resp.Header.VisitAll(func(key, value []byte) { resp.Header.VisitAll(func(key, value []byte) {
state.PushString(string(value)) headers[string(key)] = string(value)
state.SetField(-2, string(key))
}) })
state.SetField(-2, "headers") builder.SetTable("headers", headers)
// Create ok field (true if status code is 2xx) // Create ok field (true if status code is 2xx)
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300) builder.SetBool("ok", resp.StatusCode() >= 200 && resp.StatusCode() < 300)
state.SetField(-2, "ok")
builder.Build()
return 1 return 1
} }
@ -303,8 +282,10 @@ func httpRequest(state *luajit.State) int {
func generateToken(state *luajit.State) int { func generateToken(state *luajit.State) int {
// Get the length from the Lua arguments (default to 32) // Get the length from the Lua arguments (default to 32)
length := 32 length := 32
if state.GetTop() >= 1 && state.IsNumber(1) { if state.GetTop() >= 1 {
length = int(state.ToNumber(1)) if lengthVal, err := state.SafeToNumber(1); err == nil {
length = int(lengthVal)
}
} }
// Enforce minimum length for security // Enforce minimum length for security

View File

@ -1,8 +1,6 @@
package runner package runner
import ( import (
"fmt"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/alexedwards/argon2id" "github.com/alexedwards/argon2id"
) )
@ -20,12 +18,14 @@ func RegisterPasswordFunctions(state *luajit.State) error {
// passwordHash implements the Argon2id password hashing using alexedwards/argon2id // passwordHash implements the Argon2id password hashing using alexedwards/argon2id
func passwordHash(state *luajit.State) int { func passwordHash(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckMinArgs(1); err != nil {
state.PushString("password_hash error: expected string password") return state.PushError("password_hash: %v", err)
return 1
} }
password := state.ToString(1) password, err := state.SafeToString(1)
if err != nil {
return state.PushError("password_hash: password must be string")
}
params := &argon2id.Params{ params := &argon2id.Params{
Memory: 128 * 1024, Memory: 128 * 1024,
@ -35,42 +35,32 @@ func passwordHash(state *luajit.State) int {
KeyLength: 32, KeyLength: 32,
} }
if state.IsTable(2) { if state.GetTop() >= 2 && state.IsTable(2) {
state.GetField(2, "memory") // Use new field getters with validation
if state.IsNumber(-1) { if memory := state.GetFieldNumber(2, "memory", 0); memory > 0 {
params.Memory = max(uint32(state.ToNumber(-1)), 8*1024) params.Memory = max(uint32(memory), 8*1024)
} }
state.Pop(1)
state.GetField(2, "iterations") if iterations := state.GetFieldNumber(2, "iterations", 0); iterations > 0 {
if state.IsNumber(-1) { params.Iterations = max(uint32(iterations), 1)
params.Iterations = max(uint32(state.ToNumber(-1)), 1)
} }
state.Pop(1)
state.GetField(2, "parallelism") if parallelism := state.GetFieldNumber(2, "parallelism", 0); parallelism > 0 {
if state.IsNumber(-1) { params.Parallelism = max(uint8(parallelism), 1)
params.Parallelism = max(uint8(state.ToNumber(-1)), 1)
} }
state.Pop(1)
state.GetField(2, "salt_length") if saltLength := state.GetFieldNumber(2, "salt_length", 0); saltLength > 0 {
if state.IsNumber(-1) { params.SaltLength = max(uint32(saltLength), 8)
params.SaltLength = max(uint32(state.ToNumber(-1)), 8)
} }
state.Pop(1)
state.GetField(2, "key_length") if keyLength := state.GetFieldNumber(2, "key_length", 0); keyLength > 0 {
if state.IsNumber(-1) { params.KeyLength = max(uint32(keyLength), 16)
params.KeyLength = max(uint32(state.ToNumber(-1)), 16)
} }
state.Pop(1)
} }
hash, err := argon2id.CreateHash(password, params) hash, err := argon2id.CreateHash(password, params)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("password_hash error: %v", err)) return state.PushError("password_hash: %v", err)
return 1
} }
state.PushString(hash) state.PushString(hash)
@ -79,13 +69,22 @@ func passwordHash(state *luajit.State) int {
// passwordVerify verifies a password against a hash // passwordVerify verifies a password against a hash
func passwordVerify(state *luajit.State) int { func passwordVerify(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) { if err := state.CheckExactArgs(2); err != nil {
state.PushBoolean(false) state.PushBoolean(false)
return 1 return 1
} }
password := state.ToString(1) password, err := state.SafeToString(1)
hash := state.ToString(2) if err != nil {
state.PushBoolean(false)
return 1
}
hash, err := state.SafeToString(2)
if err != nil {
state.PushBoolean(false)
return 1
}
match, err := argon2id.ComparePasswordAndHash(password, hash) match, err := argon2id.ComparePasswordAndHash(password, hash)
if err != nil { if err != nil {

View File

@ -156,7 +156,7 @@ func (r *Runner) createState(index int) (*State, error) {
logger.Debugf("Creating Lua state %d", index) logger.Debugf("Creating Lua state %d", index)
} }
L := luajit.New() L := luajit.New(true) // Explicitly open standard libraries
if L == nil { if L == nil {
return nil, errors.New("failed to create Lua state") return nil, errors.New("failed to create Lua state")
} }

View File

@ -130,42 +130,30 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Execute runs a Lua script in the sandbox with the given context // Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
state.GetGlobal("__execute_script") // Use CallGlobal for cleaner function calling
if !state.IsFunction(-1) { results, err := state.CallGlobal("__execute_script",
state.Pop(1) func() any {
return nil, ErrSandboxNotInitialized if err := state.LoadBytecode(bytecode, "script"); err != nil {
} return nil
}
return "loaded_script"
}(),
ctx.Values)
if err := state.LoadBytecode(bytecode, "script"); err != nil { if err != nil {
state.Pop(1) // Pop the __execute_script function
return nil, fmt.Errorf("failed to load script: %w", err)
}
// Push context values
if err := state.PushTable(ctx.Values); err != nil {
state.Pop(2) // Pop bytecode and __execute_script
return nil, err
}
// Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil {
return nil, fmt.Errorf("script execution failed: %w", err) return nil, fmt.Errorf("script execution failed: %w", err)
} }
body, err := state.ToValue(-1)
state.Pop(1)
response := NewResponse() response := NewResponse()
if err == nil { if len(results) > 0 {
response.Body = body response.Body = results[0]
} }
extractHTTPResponseData(state, response) extractHTTPResponseData(state, response)
return response, nil return response, nil
} }
// extractResponseData pulls response info from the Lua state // extractResponseData pulls response info from the Lua state using new API
func extractHTTPResponseData(state *luajit.State, response *Response) { func extractHTTPResponseData(state *luajit.State, response *Response) {
state.GetGlobal("__http_response") state.GetGlobal("__http_response")
if !state.IsTable(-1) { if !state.IsTable(-1) {
@ -173,162 +161,113 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
return return
} }
// Extract status // Use new field getters with defaults
state.GetField(-1, "status") response.Status = int(state.GetFieldNumber(-1, "status", 200))
if state.IsNumber(-1) {
response.Status = int(state.ToNumber(-1))
}
state.Pop(1)
// Extract headers // Extract headers using ForEachTableKV
state.GetField(-1, "headers") if headerTable, ok := state.GetFieldTable(-1, "headers"); ok {
if state.IsTable(-1) { if headers, ok := headerTable.(map[string]any); ok {
state.PushNil() // Start iteration for k, v := range headers {
for state.Next(-2) { if str, ok := v.(string); ok {
if state.IsString(-2) && state.IsString(-1) { response.Headers[k] = str
key := state.ToString(-2) }
value := state.ToString(-1)
response.Headers[key] = value
} }
state.Pop(1)
} }
} }
state.Pop(1)
// Extract cookies // Extract cookies using ForEachArray
state.GetField(-1, "cookies") state.GetField(-1, "cookies")
if state.IsTable(-1) { if state.IsTable(-1) {
length := state.GetTableLength(-1) state.ForEachArray(-1, func(i int, s *luajit.State) bool {
for i := 1; i <= length; i++ { if s.IsTable(-1) {
state.PushNumber(float64(i)) extractCookie(s, response)
state.GetTable(-2)
if state.IsTable(-1) {
extractCookie(state, response)
} }
state.Pop(1) return true
} })
} }
state.Pop(1) state.Pop(1)
// Extract metadata // Extract metadata
state.GetField(-1, "metadata") if metadata, ok := state.GetFieldTable(-1, "metadata"); ok {
if state.IsTable(-1) { if metaMap, ok := metadata.(map[string]any); ok {
table, err := state.ToTable(-1) maps.Copy(response.Metadata, metaMap)
if err == nil {
maps.Copy(response.Metadata, table)
} }
} }
state.Pop(1)
// Extract session data // Extract session data
state.GetField(-1, "session") if session, ok := state.GetFieldTable(-1, "session"); ok {
if state.IsTable(-1) { if sessMap, ok := session.(map[string]any); ok {
table, err := state.ToTable(-1) maps.Copy(response.SessionData, sessMap)
if err == nil {
maps.Copy(response.SessionData, table)
} }
} }
state.Pop(1)
state.Pop(1) // Pop __http_response state.Pop(1) // Pop __http_response
} }
// extractCookie pulls cookie data from the current table on the stack // extractCookie pulls cookie data from the current table on the stack using new API
func extractCookie(state *luajit.State, response *Response) { func extractCookie(state *luajit.State, response *Response) {
cookie := fasthttp.AcquireCookie() cookie := fasthttp.AcquireCookie()
// Get name (required) // Use new field getters with defaults
state.GetField(-1, "name") name := state.GetFieldString(-1, "name", "")
if !state.IsString(-1) { if name == "" {
state.Pop(1)
fasthttp.ReleaseCookie(cookie) fasthttp.ReleaseCookie(cookie)
return return
} }
cookie.SetKey(state.ToString(-1))
state.Pop(1)
// Get value cookie.SetKey(name)
state.GetField(-1, "value") cookie.SetValue(state.GetFieldString(-1, "value", ""))
if state.IsString(-1) { cookie.SetPath(state.GetFieldString(-1, "path", "/"))
cookie.SetValue(state.ToString(-1)) cookie.SetDomain(state.GetFieldString(-1, "domain", ""))
} cookie.SetHTTPOnly(state.GetFieldBool(-1, "http_only", false))
state.Pop(1) cookie.SetSecure(state.GetFieldBool(-1, "secure", false))
cookie.SetMaxAge(int(state.GetFieldNumber(-1, "max_age", 0)))
// Get path
state.GetField(-1, "path")
if state.IsString(-1) {
cookie.SetPath(state.ToString(-1))
} else {
cookie.SetPath("/") // Default
}
state.Pop(1)
// Get domain
state.GetField(-1, "domain")
if state.IsString(-1) {
cookie.SetDomain(state.ToString(-1))
}
state.Pop(1)
// Get other parameters
state.GetField(-1, "http_only")
if state.IsBoolean(-1) {
cookie.SetHTTPOnly(state.ToBoolean(-1))
}
state.Pop(1)
state.GetField(-1, "secure")
if state.IsBoolean(-1) {
cookie.SetSecure(state.ToBoolean(-1))
}
state.Pop(1)
state.GetField(-1, "max_age")
if state.IsNumber(-1) {
cookie.SetMaxAge(int(state.ToNumber(-1)))
}
state.Pop(1)
response.Cookies = append(response.Cookies, cookie) response.Cookies = append(response.Cookies, cookie)
} }
// jsonMarshal converts a Lua value to a JSON string // jsonMarshal converts a Lua value to a JSON string with validation
func jsonMarshal(state *luajit.State) int { func jsonMarshal(state *luajit.State) int {
value, err := state.ToValue(1) if err := state.CheckExactArgs(1); err != nil {
return state.PushError("json marshal: %v", err)
}
value, err := state.SafeToTable(1)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("json marshal error: %v", err)) // Try as generic value if not a table
return 1 value, err = state.ToValue(1)
if err != nil {
return state.PushError("json marshal error: %v", err)
}
} }
bytes, err := json.Marshal(value) bytes, err := json.Marshal(value)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("json marshal error: %v", err)) return state.PushError("json marshal error: %v", err)
return 1
} }
state.PushString(string(bytes)) state.PushString(string(bytes))
return 1 return 1
} }
// jsonUnmarshal converts a JSON string to a Lua value // jsonUnmarshal converts a JSON string to a Lua value with validation
func jsonUnmarshal(state *luajit.State) int { func jsonUnmarshal(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushString("json unmarshal error: expected string") return state.PushError("json unmarshal: %v", err)
return 1 }
jsonStr, err := state.SafeToString(1)
if err != nil {
return state.PushError("json unmarshal: expected string, got %s", state.GetType(1))
} }
jsonStr := state.ToString(1)
var value any var value any
err := json.Unmarshal([]byte(jsonStr), &value) if err := json.Unmarshal([]byte(jsonStr), &value); err != nil {
if err != nil { return state.PushError("json unmarshal error: %v", err)
state.PushString(fmt.Sprintf("json unmarshal error: %v", err))
return 1
} }
if err := state.PushValue(value); err != nil { if err := state.PushValue(value); err != nil {
state.PushString(fmt.Sprintf("json unmarshal error: %v", err)) return state.PushError("json unmarshal error: %v", err)
return 1
} }
return 1 return 1
} }

View File

@ -111,20 +111,24 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
// sqlQuery executes a SQL query and returns results // sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int { func sqlQuery(state *luajit.State) int {
// Get required parameters if err := state.CheckMinArgs(2); err != nil {
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { return state.PushError("sqlite.query: %v", err)
state.PushString("sqlite.query: requires database name and query")
return -1
} }
dbName := state.ToString(1) dbName, err := state.SafeToString(1)
query := state.ToString(2) if err != nil {
return state.PushError("sqlite.query: database name must be string")
}
query, err := state.SafeToString(2)
if err != nil {
return state.PushError("sqlite.query: query must be string")
}
// Get pool // Get pool
pool, err := getPool(dbName) pool, err := getPool(dbName)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return state.PushError("sqlite.query: %v", err)
return -1
} }
// Get connection with timeout // Get connection with timeout
@ -133,8 +137,7 @@ func sqlQuery(state *luajit.State) int {
conn, err := pool.Take(ctx) conn, err := pool.Take(ctx)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error())) return state.PushError("sqlite.query: connection timeout: %v", err)
return -1
} }
defer pool.Put(conn) defer pool.Put(conn)
@ -145,8 +148,7 @@ func sqlQuery(state *luajit.State) int {
// Set up parameters if provided // Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) { if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil { if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return state.PushError("sqlite.query: %v", err)
return -1
} }
} }
@ -182,19 +184,12 @@ func sqlQuery(state *luajit.State) int {
// Execute query // Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil { if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return state.PushError("sqlite.query: %v", err)
return -1
} }
// Create result table // Create result using specific map type and PushValue
state.NewTable() if err := state.PushValue(rows); err != nil {
for i, row := range rows { return state.PushError("sqlite.query: %v", err)
state.PushNumber(float64(i + 1))
if err := state.PushTable(row); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
state.SetTable(-3)
} }
return 1 return 1
@ -202,20 +197,24 @@ func sqlQuery(state *luajit.State) int {
// sqlExec executes a SQL statement without returning results // sqlExec executes a SQL statement without returning results
func sqlExec(state *luajit.State) int { func sqlExec(state *luajit.State) int {
// Get required parameters if err := state.CheckMinArgs(2); err != nil {
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { return state.PushError("sqlite.exec: %v", err)
state.PushString("sqlite.exec: requires database name and query")
return -1
} }
dbName := state.ToString(1) dbName, err := state.SafeToString(1)
query := state.ToString(2) if err != nil {
return state.PushError("sqlite.exec: database name must be string")
}
query, err := state.SafeToString(2)
if err != nil {
return state.PushError("sqlite.exec: query must be string")
}
// Get pool // Get pool
pool, err := getPool(dbName) pool, err := getPool(dbName)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return state.PushError("sqlite.exec: %v", err)
return -1
} }
// Get connection with timeout // Get connection with timeout
@ -224,8 +223,7 @@ func sqlExec(state *luajit.State) int {
conn, err := pool.Take(ctx) conn, err := pool.Take(ctx)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error())) return state.PushError("sqlite.exec: connection timeout: %v", err)
return -1
} }
defer pool.Put(conn) defer pool.Put(conn)
@ -235,8 +233,7 @@ func sqlExec(state *luajit.State) int {
// Fast path for multi-statement scripts // Fast path for multi-statement scripts
if strings.Contains(query, ";") && !hasParams { if strings.Contains(query, ";") && !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil { if err := sqlitex.ExecScript(conn, query); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return state.PushError("sqlite.exec: %v", err)
return -1
} }
state.PushNumber(float64(conn.Changes())) state.PushNumber(float64(conn.Changes()))
return 1 return 1
@ -245,8 +242,7 @@ func sqlExec(state *luajit.State) int {
// Fast path for simple queries with no parameters // Fast path for simple queries with no parameters
if !hasParams { if !hasParams {
if err := sqlitex.Execute(conn, query, nil); err != nil { if err := sqlitex.Execute(conn, query, nil); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return state.PushError("sqlite.exec: %v", err)
return -1
} }
state.PushNumber(float64(conn.Changes())) state.PushNumber(float64(conn.Changes()))
return 1 return 1
@ -255,14 +251,12 @@ func sqlExec(state *luajit.State) int {
// Create execution options for parameterized query // Create execution options for parameterized query
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
if err := setupParams(state, 3, &execOpts); err != nil { if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return state.PushError("sqlite.exec: %v", err)
return -1
} }
// Execute with parameters // Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil { if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return state.PushError("sqlite.exec: %v", err)
return -1
} }
// Return affected rows // Return affected rows
@ -273,11 +267,17 @@ func sqlExec(state *luajit.State) int {
// setupParams configures execution options with parameters from Lua // setupParams configures execution options with parameters from Lua
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error { func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
if state.IsTable(paramIndex) { if state.IsTable(paramIndex) {
params, err := state.ToTable(paramIndex) paramsAny, err := state.SafeToTable(paramIndex)
if err != nil { if err != nil {
return fmt.Errorf("invalid parameters: %w", err) return fmt.Errorf("invalid parameters: %w", err)
} }
// Type assert to map[string]any
params, ok := paramsAny.(map[string]any)
if !ok {
return fmt.Errorf("parameters must be a table")
}
// Check for array-style params // Check for array-style params
if arr, ok := params[""]; ok { if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok { if arrParams, ok := arr.([]any); ok {
@ -321,20 +321,24 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
// sqlGetOne executes a query and returns only the first row // sqlGetOne executes a query and returns only the first row
func sqlGetOne(state *luajit.State) int { func sqlGetOne(state *luajit.State) int {
// Get required parameters if err := state.CheckMinArgs(2); err != nil {
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { return state.PushError("sqlite.get_one: %v", err)
state.PushString("sqlite.get_one: requires database name and query")
return -1
} }
dbName := state.ToString(1) dbName, err := state.SafeToString(1)
query := state.ToString(2) if err != nil {
return state.PushError("sqlite.get_one: database name must be string")
}
query, err := state.SafeToString(2)
if err != nil {
return state.PushError("sqlite.get_one: query must be string")
}
// Get pool // Get pool
pool, err := getPool(dbName) pool, err := getPool(dbName)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) return state.PushError("sqlite.get_one: %v", err)
return -1
} }
// Get connection with timeout // Get connection with timeout
@ -343,8 +347,7 @@ func sqlGetOne(state *luajit.State) int {
conn, err := pool.Take(ctx) conn, err := pool.Take(ctx)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error())) return state.PushError("sqlite.get_one: connection timeout: %v", err)
return -1
} }
defer pool.Put(conn) defer pool.Put(conn)
@ -355,8 +358,7 @@ func sqlGetOne(state *luajit.State) int {
// Set up parameters if provided // Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) { if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil { if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) return state.PushError("sqlite.get_one: %v", err)
return -1
} }
} }
@ -395,17 +397,15 @@ func sqlGetOne(state *luajit.State) int {
// Execute query // Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil { if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) return state.PushError("sqlite.get_one: %v", err)
return -1
} }
// Return result or nil if no rows // Return result or nil if no rows
if result == nil { if result == nil {
state.PushNil() state.PushNil()
} else { } else {
if err := state.PushTable(result); err != nil { if err := state.PushValue(result); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) return state.PushError("sqlite.get_one: %v", err)
return -1
} }
} }

View File

@ -35,12 +35,17 @@ func RegisterUtilFunctions(state *luajit.State) error {
// htmlSpecialChars converts special characters to HTML entities // htmlSpecialChars converts special characters to HTML entities
func htmlSpecialChars(state *luajit.State) int { func htmlSpecialChars(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
input, err := state.SafeToString(1)
if err != nil {
state.PushNil() state.PushNil()
return 1 return 1
} }
input := state.ToString(1)
result := html.EscapeString(input) result := html.EscapeString(input)
state.PushString(result) state.PushString(result)
return 1 return 1
@ -48,12 +53,17 @@ func htmlSpecialChars(state *luajit.State) int {
// htmlEntities is a more comprehensive version of htmlSpecialChars // htmlEntities is a more comprehensive version of htmlSpecialChars
func htmlEntities(state *luajit.State) int { func htmlEntities(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
input, err := state.SafeToString(1)
if err != nil {
state.PushNil() state.PushNil()
return 1 return 1
} }
input := state.ToString(1)
// First use HTML escape for standard entities // First use HTML escape for standard entities
result := html.EscapeString(input) result := html.EscapeString(input)
@ -86,12 +96,17 @@ func htmlEntities(state *luajit.State) int {
// base64Encode encodes a string to base64 // base64Encode encodes a string to base64
func base64Encode(state *luajit.State) int { func base64Encode(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
input, err := state.SafeToString(1)
if err != nil {
state.PushNil() state.PushNil()
return 1 return 1
} }
input := state.ToString(1)
result := base64.StdEncoding.EncodeToString([]byte(input)) result := base64.StdEncoding.EncodeToString([]byte(input))
state.PushString(result) state.PushString(result)
return 1 return 1
@ -99,12 +114,17 @@ func base64Encode(state *luajit.State) int {
// base64Decode decodes a base64 string // base64Decode decodes a base64 string
func base64Decode(state *luajit.State) int { func base64Decode(state *luajit.State) int {
if !state.IsString(1) { if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
input, err := state.SafeToString(1)
if err != nil {
state.PushNil() state.PushNil()
return 1 return 1
} }
input := state.ToString(1)
result, err := base64.StdEncoding.DecodeString(input) result, err := base64.StdEncoding.DecodeString(input)
if err != nil { if err != nil {
state.PushNil() state.PushNil()