migrate to new LJTG API

This commit is contained in:
Sky Johnson 2025-06-02 11:15:56 -05:00
parent cc6a7675d8
commit e2b1b932ff
10 changed files with 514 additions and 558 deletions

View File

@ -78,16 +78,21 @@ func CleanupCrypto(state *luajit.State) {
// cryptoHash generates hash digests using various algorithms
func cryptoHash(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) {
state.PushString("hash: expected (string data, string algorithm)")
return 1
if err := state.CheckMinArgs(2); err != nil {
return state.PushError("hash: %v", err)
}
data := state.ToString(1)
algorithm := state.ToString(2)
data, err := state.SafeToString(1)
if err != nil {
return state.PushError("hash: data must be string")
}
algorithm, err := state.SafeToString(2)
if err != nil {
return state.PushError("hash: algorithm must be string")
}
var h hash.Hash
switch algorithm {
case "md5":
h = md5.New()
@ -98,8 +103,7 @@ func cryptoHash(state *luajit.State) int {
case "sha512":
h = sha512.New()
default:
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
return 1
return state.PushError("unsupported algorithm: %s", algorithm)
}
h.Write([]byte(data))
@ -107,8 +111,10 @@ func cryptoHash(state *luajit.State) int {
// Output format
outputFormat := "hex"
if state.GetTop() >= 3 && state.IsString(3) {
outputFormat = state.ToString(3)
if state.GetTop() >= 3 {
if format, err := state.SafeToString(3); err == nil {
outputFormat = format
}
}
switch outputFormat {
@ -125,17 +131,26 @@ func cryptoHash(state *luajit.State) int {
// cryptoHmac generates HMAC using various hash algorithms
func cryptoHmac(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) || !state.IsString(3) {
state.PushString("hmac: expected (string data, string key, string algorithm)")
return 1
if err := state.CheckMinArgs(3); err != nil {
return state.PushError("hmac: %v", err)
}
data := state.ToString(1)
key := state.ToString(2)
algorithm := state.ToString(3)
data, err := state.SafeToString(1)
if err != nil {
return state.PushError("hmac: data must be string")
}
key, err := state.SafeToString(2)
if err != nil {
return state.PushError("hmac: key must be string")
}
algorithm, err := state.SafeToString(3)
if err != nil {
return state.PushError("hmac: algorithm must be string")
}
var h func() hash.Hash
switch algorithm {
case "md5":
h = md5.New
@ -146,8 +161,7 @@ func cryptoHmac(state *luajit.State) int {
case "sha512":
h = sha512.New
default:
state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm))
return 1
return state.PushError("unsupported algorithm: %s", algorithm)
}
mac := hmac.New(h, []byte(key))
@ -156,8 +170,10 @@ func cryptoHmac(state *luajit.State) int {
// Output format
outputFormat := "hex"
if state.GetTop() >= 4 && state.IsString(4) {
outputFormat = state.ToString(4)
if state.GetTop() >= 4 {
if format, err := state.SafeToString(4); err == nil {
outputFormat = format
}
}
switch outputFormat {
@ -175,10 +191,8 @@ func cryptoHmac(state *luajit.State) int {
// cryptoUuid generates a random UUID v4
func cryptoUuid(state *luajit.State) int {
uuid := make([]byte, 16)
_, err := rand.Read(uuid)
if err != nil {
state.PushString(fmt.Sprintf("uuid: generation error: %v", err))
return 1
if _, err := rand.Read(uuid); err != nil {
return state.PushError("uuid: generation error: %v", err)
}
// Set version (4) and variant (RFC 4122)
@ -194,15 +208,17 @@ func cryptoUuid(state *luajit.State) int {
// cryptoRandomBytes generates random bytes
func cryptoRandomBytes(state *luajit.State) int {
if !state.IsNumber(1) {
state.PushString("random_bytes: expected (number length)")
return 1
if err := state.CheckMinArgs(1); err != nil {
return state.PushError("random_bytes: %v", err)
}
length, err := state.SafeToNumber(1)
if err != nil {
return state.PushError("random_bytes: length must be number")
}
length := int(state.ToNumber(1))
if length <= 0 {
state.PushString("random_bytes: length must be positive")
return 1
return state.PushError("random_bytes: length must be positive")
}
// Check if secure
@ -211,13 +227,11 @@ func cryptoRandomBytes(state *luajit.State) int {
secure = state.ToBoolean(2)
}
bytes := make([]byte, length)
bytes := make([]byte, int(length))
if secure {
_, err := rand.Read(bytes)
if err != nil {
state.PushString(fmt.Sprintf("random_bytes: error: %v", err))
return 1
if _, err := rand.Read(bytes); err != nil {
return state.PushError("random_bytes: error: %v", err)
}
} else {
stateRngsMu.Lock()
@ -225,8 +239,7 @@ func cryptoRandomBytes(state *luajit.State) int {
stateRngsMu.Unlock()
if !ok {
state.PushString("random_bytes: RNG not initialized")
return 1
return state.PushError("random_bytes: RNG not initialized")
}
for i := range bytes {
@ -236,8 +249,10 @@ func cryptoRandomBytes(state *luajit.State) int {
// Output format
outputFormat := "binary"
if state.GetTop() >= 3 && state.IsString(3) {
outputFormat = state.ToString(3)
if state.GetTop() >= 3 {
if format, err := state.SafeToString(3); err == nil {
outputFormat = format
}
}
switch outputFormat {
@ -254,17 +269,25 @@ func cryptoRandomBytes(state *luajit.State) int {
// cryptoRandomInt generates a random integer in range [min, max]
func cryptoRandomInt(state *luajit.State) int {
if !state.IsNumber(1) || !state.IsNumber(2) {
state.PushString("random_int: expected (number min, number max)")
return 1
if err := state.CheckMinArgs(2); err != nil {
return state.PushError("random_int: %v", err)
}
min := int64(state.ToNumber(1))
max := int64(state.ToNumber(2))
minVal, err := state.SafeToNumber(1)
if err != nil {
return state.PushError("random_int: min must be number")
}
maxVal, err := state.SafeToNumber(2)
if err != nil {
return state.PushError("random_int: max must be number")
}
min := int64(minVal)
max := int64(maxVal)
if max <= min {
state.PushString("random_int: max must be greater than min")
return 1
return state.PushError("random_int: max must be greater than min")
}
// Check if secure
@ -274,15 +297,12 @@ func cryptoRandomInt(state *luajit.State) int {
}
range_size := max - min + 1
var result int64
if secure {
bytes := make([]byte, 8)
_, err := rand.Read(bytes)
if err != nil {
state.PushString(fmt.Sprintf("random_int: error: %v", err))
return 1
if _, err := rand.Read(bytes); err != nil {
return state.PushError("random_int: error: %v", err)
}
val := binary.BigEndian.Uint64(bytes)
@ -293,8 +313,7 @@ func cryptoRandomInt(state *luajit.State) int {
stateRngsMu.Unlock()
if !ok {
state.PushString("random_int: RNG not initialized")
return 1
return state.PushError("random_int: RNG not initialized")
}
result = min + int64(stateRng.Uint64()%uint64(range_size))
@ -315,10 +334,8 @@ func cryptoRandom(state *luajit.State) int {
if numArgs == 0 {
if secure {
bytes := make([]byte, 8)
_, err := rand.Read(bytes)
if err != nil {
state.PushString(fmt.Sprintf("random: error: %v", err))
return 1
if _, err := rand.Read(bytes); err != nil {
return state.PushError("random: error: %v", err)
}
val := binary.BigEndian.Uint64(bytes)
state.PushNumber(float64(val) / float64(math.MaxUint64))
@ -328,8 +345,7 @@ func cryptoRandom(state *luajit.State) int {
stateRngsMu.Unlock()
if !ok {
state.PushString("random: RNG not initialized")
return 1
return state.PushError("random: RNG not initialized")
}
state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64))
@ -341,8 +357,7 @@ func cryptoRandom(state *luajit.State) int {
if numArgs == 1 && state.IsNumber(1) {
n := int64(state.ToNumber(1))
if n < 1 {
state.PushString("random: upper bound must be >= 1")
return 1
return state.PushError("random: upper bound must be >= 1")
}
state.PushNumber(1) // min
@ -357,21 +372,23 @@ func cryptoRandom(state *luajit.State) int {
return cryptoRandomInt(state)
}
state.PushString("random: invalid arguments")
return 1
return state.PushError("random: invalid arguments")
}
// cryptoRandomSeed sets seed for non-secure RNG
func cryptoRandomSeed(state *luajit.State) int {
if !state.IsNumber(1) {
state.PushString("randomseed: expected (number seed)")
return 1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("randomseed: %v", err)
}
seed := uint64(state.ToNumber(1))
seed, err := state.SafeToNumber(1)
if err != nil {
return state.PushError("randomseed: seed must be number")
}
seedVal := uint64(seed)
stateRngsMu.Lock()
stateRngs[state] = mrand.NewPCG(seed, seed>>32)
stateRngs[state] = mrand.NewPCG(seedVal, seedVal>>32)
stateRngsMu.Unlock()
return 0

View File

@ -71,7 +71,7 @@ var (
// precompileModule compiles a module's code to bytecode once
func precompileModule(m *ModuleInfo) {
m.Once.Do(func() {
tempState := luajit.New()
tempState := luajit.New(true) // Explicitly open standard libraries
if tempState == nil {
logger.Fatalf("Failed to create temp Lua state for %s module compilation", m.Name)
return
@ -105,18 +105,14 @@ func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error {
logger.Debugf("Loading %s.lua from precompiled bytecode", m.Name)
}
if err := state.LoadBytecode(*bytecode, m.Name+".lua"); err != nil {
return err
}
if m.DefinesGlobal {
// Module defines its own globals, just run it
if err := state.RunBytecode(); err != nil {
if err := state.LoadAndRunBytecode(*bytecode, m.Name+".lua"); err != nil {
return err
}
} else {
// Module returns a table, capture and set as global
if err := state.RunBytecodeWithResults(1); err != nil {
if err := state.LoadAndRunBytecodeWithResults(*bytecode, m.Name+".lua", 1); err != nil {
return err
}
state.SetGlobal(m.Name)

View File

@ -221,12 +221,17 @@ func CleanupEnv() error {
// envGet Lua function to get an environment variable
func envGet(state *luajit.State) int {
if !state.IsString(1) {
if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
key, err := state.SafeToString(1)
if err != nil {
state.PushNil()
return 1
}
key := state.ToString(1)
if value, exists := globalEnvManager.Get(key); exists {
if err := state.PushValue(value); err != nil {
state.PushNil()
@ -239,13 +244,22 @@ func envGet(state *luajit.State) int {
// envSet Lua function to set an environment variable
func envSet(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) {
if err := state.CheckExactArgs(2); err != nil {
state.PushBoolean(false)
return 1
}
key := state.ToString(1)
value := state.ToString(2)
key, err := state.SafeToString(1)
if err != nil {
state.PushBoolean(false)
return 1
}
value, err := state.SafeToString(2)
if err != nil {
state.PushBoolean(false)
return 1
}
globalEnvManager.Set(key, value)
state.PushBoolean(true)
@ -255,11 +269,9 @@ func envSet(state *luajit.State) int {
// envGetAll Lua function to get all environment variables
func envGetAll(state *luajit.State) int {
vars := globalEnvManager.GetAll()
if err := state.PushTable(vars); err != nil {
if err := state.PushValue(vars); err != nil {
state.PushNil()
}
return 1
}

View File

@ -112,23 +112,24 @@ func getCacheKey(fullPath string, modTime time.Time) string {
// fsReadFile reads a file and returns its contents
func fsReadFile(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.read_file: path must be a string")
return -1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("fs.read_file: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.read_file: path must be string")
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.read_file: " + err.Error())
return -1
return state.PushError("fs.read_file: %v", err)
}
// Get file info for cache key and validation
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.read_file: " + err.Error())
return -1
return state.PushError("fs.read_file: %v", err)
}
// Create cache key with path and modification time
@ -154,8 +155,7 @@ func fsReadFile(state *luajit.State) int {
stats.misses++
data, err := os.ReadFile(fullPath)
if err != nil {
state.PushString("fs.read_file: " + err.Error())
return -1
return state.PushError("fs.read_file: %v", err)
}
// Compress and cache the data
@ -170,41 +170,33 @@ func fsReadFile(state *luajit.State) int {
// fsWriteFile writes data to a file
func fsWriteFile(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.write_file: path must be a string")
return -1
if err := state.CheckExactArgs(2); err != nil {
return state.PushError("fs.write_file: %v", err)
}
path := state.ToString(1)
if !state.IsString(2) {
state.PushString("fs.write_file: content must be a string")
return -1
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.write_file: path must be string")
}
content, err := state.SafeToString(2)
if err != nil {
return state.PushError("fs.write_file: content must be string")
}
content := state.ToString(2)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.write_file: " + err.Error())
return -1
return state.PushError("fs.write_file: %v", err)
}
// Ensure the directory exists
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
state.PushString("fs.write_file: failed to create directory: " + err.Error())
return -1
return state.PushError("fs.write_file: failed to create directory: %v", err)
}
err = os.WriteFile(fullPath, []byte(content), 0644)
if err != nil {
state.PushString("fs.write_file: " + err.Error())
return -1
}
// Invalidate cache entries for this file path
if fileCache != nil {
// We can't easily iterate through cache keys, so we'll let the cache
// naturally expire old entries when the file is read again
if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil {
return state.PushError("fs.write_file: %v", err)
}
state.PushBoolean(true)
@ -213,42 +205,39 @@ func fsWriteFile(state *luajit.State) int {
// fsAppendFile appends data to a file
func fsAppendFile(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.append_file: path must be a string")
return -1
if err := state.CheckExactArgs(2); err != nil {
return state.PushError("fs.append_file: %v", err)
}
path := state.ToString(1)
if !state.IsString(2) {
state.PushString("fs.append_file: content must be a string")
return -1
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.append_file: path must be string")
}
content, err := state.SafeToString(2)
if err != nil {
return state.PushError("fs.append_file: content must be string")
}
content := state.ToString(2)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.append_file: " + err.Error())
return -1
return state.PushError("fs.append_file: %v", err)
}
// Ensure the directory exists
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
state.PushString("fs.append_file: failed to create directory: " + err.Error())
return -1
return state.PushError("fs.append_file: failed to create directory: %v", err)
}
file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
state.PushString("fs.append_file: " + err.Error())
return -1
return state.PushError("fs.append_file: %v", err)
}
defer file.Close()
_, err = file.Write([]byte(content))
if err != nil {
state.PushString("fs.append_file: " + err.Error())
return -1
if _, err = file.Write([]byte(content)); err != nil {
return state.PushError("fs.append_file: %v", err)
}
state.PushBoolean(true)
@ -257,16 +246,18 @@ func fsAppendFile(state *luajit.State) int {
// fsExists checks if a file or directory exists
func fsExists(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.exists: path must be a string")
return -1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("fs.exists: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.exists: path must be string")
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.exists: " + err.Error())
return -1
return state.PushError("fs.exists: %v", err)
}
_, err = os.Stat(fullPath)
@ -276,34 +267,32 @@ func fsExists(state *luajit.State) int {
// fsRemoveFile removes a file
func fsRemoveFile(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.remove_file: path must be a string")
return -1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("fs.remove_file: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.remove_file: path must be string")
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.remove_file: " + err.Error())
return -1
return state.PushError("fs.remove_file: %v", err)
}
// Check if it's a directory
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.remove_file: " + err.Error())
return -1
return state.PushError("fs.remove_file: %v", err)
}
if info.IsDir() {
state.PushString("fs.remove_file: cannot remove directory, use remove_dir instead")
return -1
return state.PushError("fs.remove_file: cannot remove directory, use remove_dir instead")
}
err = os.Remove(fullPath)
if err != nil {
state.PushString("fs.remove_file: " + err.Error())
return -1
if err := os.Remove(fullPath); err != nil {
return state.PushError("fs.remove_file: %v", err)
}
state.PushBoolean(true)
@ -312,67 +301,65 @@ func fsRemoveFile(state *luajit.State) int {
// fsGetInfo gets information about a file
func fsGetInfo(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.get_info: path must be a string")
return -1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("fs.get_info: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.get_info: path must be string")
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.get_info: " + err.Error())
return -1
return state.PushError("fs.get_info: %v", err)
}
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.get_info: " + err.Error())
return -1
return state.PushError("fs.get_info: %v", err)
}
state.NewTable()
fileInfo := map[string]any{
"name": info.Name(),
"size": info.Size(),
"mode": int(info.Mode()),
"mod_time": info.ModTime().Unix(),
"is_dir": info.IsDir(),
}
state.PushString(info.Name())
state.SetField(-2, "name")
state.PushNumber(float64(info.Size()))
state.SetField(-2, "size")
state.PushNumber(float64(info.Mode()))
state.SetField(-2, "mode")
state.PushNumber(float64(info.ModTime().Unix()))
state.SetField(-2, "mod_time")
state.PushBoolean(info.IsDir())
state.SetField(-2, "is_dir")
if err := state.PushValue(fileInfo); err != nil {
return state.PushError("fs.get_info: %v", err)
}
return 1
}
// fsMakeDir creates a directory
func fsMakeDir(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.make_dir: path must be a string")
return -1
if err := state.CheckMinArgs(1); err != nil {
return state.PushError("fs.make_dir: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.make_dir: path must be string")
}
path := state.ToString(1)
perm := os.FileMode(0755)
if state.GetTop() >= 2 && state.IsNumber(2) {
perm = os.FileMode(state.ToNumber(2))
if state.GetTop() >= 2 {
if permVal, err := state.SafeToNumber(2); err == nil {
perm = os.FileMode(permVal)
}
}
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.make_dir: " + err.Error())
return -1
return state.PushError("fs.make_dir: %v", err)
}
err = os.MkdirAll(fullPath, perm)
if err != nil {
state.PushString("fs.make_dir: " + err.Error())
return -1
if err := os.MkdirAll(fullPath, perm); err != nil {
return state.PushError("fs.make_dir: %v", err)
}
state.PushBoolean(true)
@ -381,41 +368,42 @@ func fsMakeDir(state *luajit.State) int {
// fsListDir lists the contents of a directory
func fsListDir(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.list_dir: path must be a string")
return -1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("fs.list_dir: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.list_dir: path must be string")
}
path := state.ToString(1)
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.list_dir: " + err.Error())
return -1
return state.PushError("fs.list_dir: %v", err)
}
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.list_dir: " + err.Error())
return -1
return state.PushError("fs.list_dir: %v", err)
}
if !info.IsDir() {
state.PushString("fs.list_dir: not a directory")
return -1
return state.PushError("fs.list_dir: not a directory")
}
files, err := os.ReadDir(fullPath)
if err != nil {
state.PushString("fs.list_dir: " + err.Error())
return -1
return state.PushError("fs.list_dir: %v", err)
}
state.NewTable()
// Create array of filenames
filenames := make([]string, len(files))
for i, file := range files {
state.PushNumber(float64(i + 1))
state.PushString(file.Name())
state.SetTable(-3)
filenames[i] = file.Name()
}
if err := state.PushValue(filenames); err != nil {
return state.PushError("fs.list_dir: %v", err)
}
return 1
@ -423,11 +411,14 @@ func fsListDir(state *luajit.State) int {
// fsRemoveDir removes a directory
func fsRemoveDir(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.remove_dir: path must be a string")
return -1
if err := state.CheckMinArgs(1); err != nil {
return state.PushError("fs.remove_dir: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.remove_dir: path must be string")
}
path := state.ToString(1)
recursive := false
if state.GetTop() >= 2 {
@ -436,19 +427,16 @@ func fsRemoveDir(state *luajit.State) int {
fullPath, err := ResolvePath(path)
if err != nil {
state.PushString("fs.remove_dir: " + err.Error())
return -1
return state.PushError("fs.remove_dir: %v", err)
}
info, err := os.Stat(fullPath)
if err != nil {
state.PushString("fs.remove_dir: " + err.Error())
return -1
return state.PushError("fs.remove_dir: %v", err)
}
if !info.IsDir() {
state.PushString("fs.remove_dir: not a directory")
return -1
return state.PushError("fs.remove_dir: not a directory")
}
if recursive {
@ -458,8 +446,7 @@ func fsRemoveDir(state *luajit.State) int {
}
if err != nil {
state.PushString("fs.remove_dir: " + err.Error())
return -1
return state.PushError("fs.remove_dir: %v", err)
}
state.PushBoolean(true)
@ -468,19 +455,17 @@ func fsRemoveDir(state *luajit.State) int {
// fsJoinPaths joins path components
func fsJoinPaths(state *luajit.State) int {
nargs := state.GetTop()
if nargs < 1 {
state.PushString("fs.join_paths: at least one path component required")
return -1
if err := state.CheckMinArgs(1); err != nil {
return state.PushError("fs.join_paths: %v", err)
}
components := make([]string, nargs)
for i := 1; i <= nargs; i++ {
if !state.IsString(i) {
state.PushString("fs.join_paths: all arguments must be strings")
return -1
components := make([]string, state.GetTop())
for i := 1; i <= state.GetTop(); i++ {
comp, err := state.SafeToString(i)
if err != nil {
return state.PushError("fs.join_paths: all arguments must be strings")
}
components[i-1] = state.ToString(i)
components[i-1] = comp
}
result := filepath.Join(components...)
@ -492,11 +477,14 @@ func fsJoinPaths(state *luajit.State) int {
// fsDirName returns the directory portion of a path
func fsDirName(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.dir_name: path must be a string")
return -1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("fs.dir_name: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.dir_name: path must be string")
}
path := state.ToString(1)
dir := filepath.Dir(path)
dir = strings.ReplaceAll(dir, "\\", "/")
@ -507,28 +495,32 @@ func fsDirName(state *luajit.State) int {
// fsBaseName returns the file name portion of a path
func fsBaseName(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.base_name: path must be a string")
return -1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("fs.base_name: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.base_name: path must be string")
}
path := state.ToString(1)
base := filepath.Base(path)
state.PushString(base)
return 1
}
// fsExtension returns the file extension
func fsExtension(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("fs.extension: path must be a string")
return -1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("fs.extension: %v", err)
}
path, err := state.SafeToString(1)
if err != nil {
return state.PushError("fs.extension: path must be string")
}
path := state.ToString(1)
ext := filepath.Ext(path)
state.PushString(ext)
return 1
}

View File

@ -94,25 +94,26 @@ func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) {
// httpRequest makes an HTTP request and returns the result to Lua
func httpRequest(state *luajit.State) int {
// Get method (required)
if !state.IsString(1) {
state.PushString("http.client.request: method must be a string")
return -1
if err := state.CheckMinArgs(2); err != nil {
return state.PushError("http.client.request: %v", err)
}
method := strings.ToUpper(state.ToString(1))
// Get URL (required)
if !state.IsString(2) {
state.PushString("http.client.request: url must be a string")
return -1
// Get method and URL
method, err := state.SafeToString(1)
if err != nil {
return state.PushError("http.client.request: method must be string")
}
method = strings.ToUpper(method)
urlStr, err := state.SafeToString(2)
if err != nil {
return state.PushError("http.client.request: url must be string")
}
urlStr := state.ToString(2)
// Parse URL to check if it's valid
parsedURL, err := url.Parse(urlStr)
if err != nil {
state.PushString("Invalid URL: " + err.Error())
return -1
return state.PushError("Invalid URL: %v", err)
}
// Get client configuration
@ -120,8 +121,7 @@ func httpRequest(state *luajit.State) int {
// Check if remote connections are allowed
if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") {
state.PushString("Remote connections are not allowed")
return -1
return state.PushError("Remote connections are not allowed")
}
// Use bytebufferpool for request and response
@ -139,13 +139,13 @@ func httpRequest(state *luajit.State) int {
if state.GetTop() >= 3 && !state.IsNil(3) {
if state.IsString(3) {
// String body
req.SetBodyString(state.ToString(3))
bodyStr, _ := state.SafeToString(3)
req.SetBodyString(bodyStr)
} else if state.IsTable(3) {
// Table body - convert to JSON
luaTable, err := state.ToTable(3)
luaTable, err := state.SafeToTable(3)
if err != nil {
state.PushString("Failed to parse body table: " + err.Error())
return -1
return state.PushError("Failed to parse body table: %v", err)
}
// Use bytebufferpool for JSON serialization
@ -153,42 +153,37 @@ func httpRequest(state *luajit.State) int {
defer bytebufferpool.Put(buf)
if err := json.NewEncoder(buf).Encode(luaTable); err != nil {
state.PushString("Failed to convert body to JSON: " + err.Error())
return -1
return state.PushError("Failed to convert body to JSON: %v", err)
}
req.SetBody(buf.Bytes())
req.Header.SetContentType("application/json")
} else {
state.PushString("Body must be a string or table")
return -1
return state.PushError("Body must be a string or table")
}
}
// Process options (headers, timeout, etc.)
timeout := config.DefaultTimeout
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) {
// Process headers
state.GetField(4, "headers")
if state.IsTable(-1) {
// Iterate through headers
state.PushNil() // Start iteration
for state.Next(-2) {
// Stack now has key at -2 and value at -1
if state.IsString(-2) && state.IsString(-1) {
headerName := state.ToString(-2)
headerValue := state.ToString(-1)
req.Header.Set(headerName, headerValue)
// Process headers using ForEachTableKV
if headers, ok := state.GetFieldTable(4, "headers"); ok {
if headerMap, ok := headers.(map[string]string); ok {
for name, value := range headerMap {
req.Header.Set(name, value)
}
} else if headerMapAny, ok := headers.(map[string]any); ok {
for name, value := range headerMapAny {
if valueStr, ok := value.(string); ok {
req.Header.Set(name, valueStr)
}
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1) // Pop headers table
// Get timeout
state.GetField(4, "timeout")
if state.IsNumber(-1) {
requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second
if timeoutVal := state.GetFieldNumber(4, "timeout", 0); timeoutVal > 0 {
requestTimeout := time.Duration(timeoutVal) * time.Second
// Apply max timeout if configured
if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout {
@ -197,38 +192,34 @@ func httpRequest(state *luajit.State) int {
timeout = requestTimeout
}
}
state.Pop(1) // Pop timeout
// Process query parameters
state.GetField(4, "query")
if state.IsTable(-1) {
// Create URL args
if query, ok := state.GetFieldTable(4, "query"); ok {
args := req.URI().QueryArgs()
// Iterate through query params
state.PushNil() // Start iteration
for state.Next(-2) {
if state.IsString(-2) {
paramName := state.ToString(-2)
// Handle different value types
if state.IsString(-1) {
args.Add(paramName, state.ToString(-1))
} else if state.IsNumber(-1) {
args.Add(paramName, strings.TrimRight(strings.TrimRight(
state.ToString(-1), "0"), "."))
} else if state.IsBoolean(-1) {
if state.ToBoolean(-1) {
args.Add(paramName, "true")
if queryMap, ok := query.(map[string]string); ok {
for name, value := range queryMap {
args.Add(name, value)
}
} else if queryMapAny, ok := query.(map[string]any); ok {
for name, value := range queryMapAny {
switch v := value.(type) {
case string:
args.Add(name, v)
case int:
args.Add(name, fmt.Sprintf("%d", v))
case float64:
args.Add(name, strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", v), "0"), "."))
case bool:
if v {
args.Add(name, "true")
} else {
args.Add(paramName, "false")
args.Add(name, "false")
}
}
}
state.Pop(1) // Pop value, leave key for next iteration
}
}
state.Pop(1) // Pop query table
}
// Create context with timeout
@ -242,26 +233,18 @@ func httpRequest(state *luajit.State) int {
if errors.Is(err, fasthttp.ErrTimeout) {
errStr = "Request timed out after " + timeout.String()
}
state.PushString(errStr)
return -1
return state.PushError("%s", errStr)
}
// Create response table
state.NewTable()
// Create response using TableBuilder
builder := state.NewTableBuilder()
// Set status code
state.PushNumber(float64(resp.StatusCode()))
state.SetField(-2, "status")
// Set status text
statusText := fasthttp.StatusMessage(resp.StatusCode())
state.PushString(statusText)
state.SetField(-2, "status_text")
// Set status code and text
builder.SetNumber("status", float64(resp.StatusCode()))
builder.SetString("status_text", fasthttp.StatusMessage(resp.StatusCode()))
// Set body
var respBody []byte
// Apply size limits to response
if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize {
// Make a limited copy
respBody = make([]byte, config.MaxResponseSize)
@ -270,32 +253,28 @@ func httpRequest(state *luajit.State) int {
respBody = resp.Body()
}
state.PushString(string(respBody))
state.SetField(-2, "body")
builder.SetString("body", string(respBody))
// Parse body as JSON if content type is application/json
contentType := string(resp.Header.ContentType())
if strings.Contains(contentType, "application/json") {
var jsonData any
if err := json.Unmarshal(respBody, &jsonData); err == nil {
if err := state.PushValue(jsonData); err == nil {
state.SetField(-2, "json")
}
builder.SetTable("json", jsonData)
}
}
// Set headers
state.NewTable()
headers := make(map[string]string)
resp.Header.VisitAll(func(key, value []byte) {
state.PushString(string(value))
state.SetField(-2, string(key))
headers[string(key)] = string(value)
})
state.SetField(-2, "headers")
builder.SetTable("headers", headers)
// Create ok field (true if status code is 2xx)
state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300)
state.SetField(-2, "ok")
builder.SetBool("ok", resp.StatusCode() >= 200 && resp.StatusCode() < 300)
builder.Build()
return 1
}
@ -303,8 +282,10 @@ func httpRequest(state *luajit.State) int {
func generateToken(state *luajit.State) int {
// Get the length from the Lua arguments (default to 32)
length := 32
if state.GetTop() >= 1 && state.IsNumber(1) {
length = int(state.ToNumber(1))
if state.GetTop() >= 1 {
if lengthVal, err := state.SafeToNumber(1); err == nil {
length = int(lengthVal)
}
}
// Enforce minimum length for security

View File

@ -1,8 +1,6 @@
package runner
import (
"fmt"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
"github.com/alexedwards/argon2id"
)
@ -20,12 +18,14 @@ func RegisterPasswordFunctions(state *luajit.State) error {
// passwordHash implements the Argon2id password hashing using alexedwards/argon2id
func passwordHash(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("password_hash error: expected string password")
return 1
if err := state.CheckMinArgs(1); err != nil {
return state.PushError("password_hash: %v", err)
}
password := state.ToString(1)
password, err := state.SafeToString(1)
if err != nil {
return state.PushError("password_hash: password must be string")
}
params := &argon2id.Params{
Memory: 128 * 1024,
@ -35,42 +35,32 @@ func passwordHash(state *luajit.State) int {
KeyLength: 32,
}
if state.IsTable(2) {
state.GetField(2, "memory")
if state.IsNumber(-1) {
params.Memory = max(uint32(state.ToNumber(-1)), 8*1024)
if state.GetTop() >= 2 && state.IsTable(2) {
// Use new field getters with validation
if memory := state.GetFieldNumber(2, "memory", 0); memory > 0 {
params.Memory = max(uint32(memory), 8*1024)
}
state.Pop(1)
state.GetField(2, "iterations")
if state.IsNumber(-1) {
params.Iterations = max(uint32(state.ToNumber(-1)), 1)
if iterations := state.GetFieldNumber(2, "iterations", 0); iterations > 0 {
params.Iterations = max(uint32(iterations), 1)
}
state.Pop(1)
state.GetField(2, "parallelism")
if state.IsNumber(-1) {
params.Parallelism = max(uint8(state.ToNumber(-1)), 1)
if parallelism := state.GetFieldNumber(2, "parallelism", 0); parallelism > 0 {
params.Parallelism = max(uint8(parallelism), 1)
}
state.Pop(1)
state.GetField(2, "salt_length")
if state.IsNumber(-1) {
params.SaltLength = max(uint32(state.ToNumber(-1)), 8)
if saltLength := state.GetFieldNumber(2, "salt_length", 0); saltLength > 0 {
params.SaltLength = max(uint32(saltLength), 8)
}
state.Pop(1)
state.GetField(2, "key_length")
if state.IsNumber(-1) {
params.KeyLength = max(uint32(state.ToNumber(-1)), 16)
if keyLength := state.GetFieldNumber(2, "key_length", 0); keyLength > 0 {
params.KeyLength = max(uint32(keyLength), 16)
}
state.Pop(1)
}
hash, err := argon2id.CreateHash(password, params)
if err != nil {
state.PushString(fmt.Sprintf("password_hash error: %v", err))
return 1
return state.PushError("password_hash: %v", err)
}
state.PushString(hash)
@ -79,13 +69,22 @@ func passwordHash(state *luajit.State) int {
// passwordVerify verifies a password against a hash
func passwordVerify(state *luajit.State) int {
if !state.IsString(1) || !state.IsString(2) {
if err := state.CheckExactArgs(2); err != nil {
state.PushBoolean(false)
return 1
}
password := state.ToString(1)
hash := state.ToString(2)
password, err := state.SafeToString(1)
if err != nil {
state.PushBoolean(false)
return 1
}
hash, err := state.SafeToString(2)
if err != nil {
state.PushBoolean(false)
return 1
}
match, err := argon2id.ComparePasswordAndHash(password, hash)
if err != nil {

View File

@ -156,7 +156,7 @@ func (r *Runner) createState(index int) (*State, error) {
logger.Debugf("Creating Lua state %d", index)
}
L := luajit.New()
L := luajit.New(true) // Explicitly open standard libraries
if L == nil {
return nil, errors.New("failed to create Lua state")
}

View File

@ -130,42 +130,30 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
state.GetGlobal("__execute_script")
if !state.IsFunction(-1) {
state.Pop(1)
return nil, ErrSandboxNotInitialized
}
// Use CallGlobal for cleaner function calling
results, err := state.CallGlobal("__execute_script",
func() any {
if err := state.LoadBytecode(bytecode, "script"); err != nil {
return nil
}
return "loaded_script"
}(),
ctx.Values)
if err := state.LoadBytecode(bytecode, "script"); err != nil {
state.Pop(1) // Pop the __execute_script function
return nil, fmt.Errorf("failed to load script: %w", err)
}
// Push context values
if err := state.PushTable(ctx.Values); err != nil {
state.Pop(2) // Pop bytecode and __execute_script
return nil, err
}
// Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil {
if err != nil {
return nil, fmt.Errorf("script execution failed: %w", err)
}
body, err := state.ToValue(-1)
state.Pop(1)
response := NewResponse()
if err == nil {
response.Body = body
if len(results) > 0 {
response.Body = results[0]
}
extractHTTPResponseData(state, response)
return response, nil
}
// extractResponseData pulls response info from the Lua state
// extractResponseData pulls response info from the Lua state using new API
func extractHTTPResponseData(state *luajit.State, response *Response) {
state.GetGlobal("__http_response")
if !state.IsTable(-1) {
@ -173,162 +161,113 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
return
}
// Extract status
state.GetField(-1, "status")
if state.IsNumber(-1) {
response.Status = int(state.ToNumber(-1))
}
state.Pop(1)
// Use new field getters with defaults
response.Status = int(state.GetFieldNumber(-1, "status", 200))
// Extract headers
state.GetField(-1, "headers")
if state.IsTable(-1) {
state.PushNil() // Start iteration
for state.Next(-2) {
if state.IsString(-2) && state.IsString(-1) {
key := state.ToString(-2)
value := state.ToString(-1)
response.Headers[key] = value
// Extract headers using ForEachTableKV
if headerTable, ok := state.GetFieldTable(-1, "headers"); ok {
if headers, ok := headerTable.(map[string]any); ok {
for k, v := range headers {
if str, ok := v.(string); ok {
response.Headers[k] = str
}
}
state.Pop(1)
}
}
state.Pop(1)
// Extract cookies
// Extract cookies using ForEachArray
state.GetField(-1, "cookies")
if state.IsTable(-1) {
length := state.GetTableLength(-1)
for i := 1; i <= length; i++ {
state.PushNumber(float64(i))
state.GetTable(-2)
if state.IsTable(-1) {
extractCookie(state, response)
state.ForEachArray(-1, func(i int, s *luajit.State) bool {
if s.IsTable(-1) {
extractCookie(s, response)
}
state.Pop(1)
}
return true
})
}
state.Pop(1)
// Extract metadata
state.GetField(-1, "metadata")
if state.IsTable(-1) {
table, err := state.ToTable(-1)
if err == nil {
maps.Copy(response.Metadata, table)
if metadata, ok := state.GetFieldTable(-1, "metadata"); ok {
if metaMap, ok := metadata.(map[string]any); ok {
maps.Copy(response.Metadata, metaMap)
}
}
state.Pop(1)
// Extract session data
state.GetField(-1, "session")
if state.IsTable(-1) {
table, err := state.ToTable(-1)
if err == nil {
maps.Copy(response.SessionData, table)
if session, ok := state.GetFieldTable(-1, "session"); ok {
if sessMap, ok := session.(map[string]any); ok {
maps.Copy(response.SessionData, sessMap)
}
}
state.Pop(1)
state.Pop(1) // Pop __http_response
}
// extractCookie pulls cookie data from the current table on the stack
// extractCookie pulls cookie data from the current table on the stack using new API
func extractCookie(state *luajit.State, response *Response) {
cookie := fasthttp.AcquireCookie()
// Get name (required)
state.GetField(-1, "name")
if !state.IsString(-1) {
state.Pop(1)
// Use new field getters with defaults
name := state.GetFieldString(-1, "name", "")
if name == "" {
fasthttp.ReleaseCookie(cookie)
return
}
cookie.SetKey(state.ToString(-1))
state.Pop(1)
// Get value
state.GetField(-1, "value")
if state.IsString(-1) {
cookie.SetValue(state.ToString(-1))
}
state.Pop(1)
// Get path
state.GetField(-1, "path")
if state.IsString(-1) {
cookie.SetPath(state.ToString(-1))
} else {
cookie.SetPath("/") // Default
}
state.Pop(1)
// Get domain
state.GetField(-1, "domain")
if state.IsString(-1) {
cookie.SetDomain(state.ToString(-1))
}
state.Pop(1)
// Get other parameters
state.GetField(-1, "http_only")
if state.IsBoolean(-1) {
cookie.SetHTTPOnly(state.ToBoolean(-1))
}
state.Pop(1)
state.GetField(-1, "secure")
if state.IsBoolean(-1) {
cookie.SetSecure(state.ToBoolean(-1))
}
state.Pop(1)
state.GetField(-1, "max_age")
if state.IsNumber(-1) {
cookie.SetMaxAge(int(state.ToNumber(-1)))
}
state.Pop(1)
cookie.SetKey(name)
cookie.SetValue(state.GetFieldString(-1, "value", ""))
cookie.SetPath(state.GetFieldString(-1, "path", "/"))
cookie.SetDomain(state.GetFieldString(-1, "domain", ""))
cookie.SetHTTPOnly(state.GetFieldBool(-1, "http_only", false))
cookie.SetSecure(state.GetFieldBool(-1, "secure", false))
cookie.SetMaxAge(int(state.GetFieldNumber(-1, "max_age", 0)))
response.Cookies = append(response.Cookies, cookie)
}
// jsonMarshal converts a Lua value to a JSON string
// jsonMarshal converts a Lua value to a JSON string with validation
func jsonMarshal(state *luajit.State) int {
value, err := state.ToValue(1)
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("json marshal: %v", err)
}
value, err := state.SafeToTable(1)
if err != nil {
state.PushString(fmt.Sprintf("json marshal error: %v", err))
return 1
// Try as generic value if not a table
value, err = state.ToValue(1)
if err != nil {
return state.PushError("json marshal error: %v", err)
}
}
bytes, err := json.Marshal(value)
if err != nil {
state.PushString(fmt.Sprintf("json marshal error: %v", err))
return 1
return state.PushError("json marshal error: %v", err)
}
state.PushString(string(bytes))
return 1
}
// jsonUnmarshal converts a JSON string to a Lua value
// jsonUnmarshal converts a JSON string to a Lua value with validation
func jsonUnmarshal(state *luajit.State) int {
if !state.IsString(1) {
state.PushString("json unmarshal error: expected string")
return 1
if err := state.CheckExactArgs(1); err != nil {
return state.PushError("json unmarshal: %v", err)
}
jsonStr, err := state.SafeToString(1)
if err != nil {
return state.PushError("json unmarshal: expected string, got %s", state.GetType(1))
}
jsonStr := state.ToString(1)
var value any
err := json.Unmarshal([]byte(jsonStr), &value)
if err != nil {
state.PushString(fmt.Sprintf("json unmarshal error: %v", err))
return 1
if err := json.Unmarshal([]byte(jsonStr), &value); err != nil {
return state.PushError("json unmarshal error: %v", err)
}
if err := state.PushValue(value); err != nil {
state.PushString(fmt.Sprintf("json unmarshal error: %v", err))
return 1
return state.PushError("json unmarshal error: %v", err)
}
return 1
}

View File

@ -111,20 +111,24 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
// sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.query: requires database name and query")
return -1
if err := state.CheckMinArgs(2); err != nil {
return state.PushError("sqlite.query: %v", err)
}
dbName := state.ToString(1)
query := state.ToString(2)
dbName, err := state.SafeToString(1)
if err != nil {
return state.PushError("sqlite.query: database name must be string")
}
query, err := state.SafeToString(2)
if err != nil {
return state.PushError("sqlite.query: query must be string")
}
// Get pool
pool, err := getPool(dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
return state.PushError("sqlite.query: %v", err)
}
// Get connection with timeout
@ -133,8 +137,7 @@ func sqlQuery(state *luajit.State) int {
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error()))
return -1
return state.PushError("sqlite.query: connection timeout: %v", err)
}
defer pool.Put(conn)
@ -145,8 +148,7 @@ func sqlQuery(state *luajit.State) int {
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
return state.PushError("sqlite.query: %v", err)
}
}
@ -182,19 +184,12 @@ func sqlQuery(state *luajit.State) int {
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
return state.PushError("sqlite.query: %v", err)
}
// Create result table
state.NewTable()
for i, row := range rows {
state.PushNumber(float64(i + 1))
if err := state.PushTable(row); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
state.SetTable(-3)
// Create result using specific map type and PushValue
if err := state.PushValue(rows); err != nil {
return state.PushError("sqlite.query: %v", err)
}
return 1
@ -202,20 +197,24 @@ func sqlQuery(state *luajit.State) int {
// sqlExec executes a SQL statement without returning results
func sqlExec(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.exec: requires database name and query")
return -1
if err := state.CheckMinArgs(2); err != nil {
return state.PushError("sqlite.exec: %v", err)
}
dbName := state.ToString(1)
query := state.ToString(2)
dbName, err := state.SafeToString(1)
if err != nil {
return state.PushError("sqlite.exec: database name must be string")
}
query, err := state.SafeToString(2)
if err != nil {
return state.PushError("sqlite.exec: query must be string")
}
// Get pool
pool, err := getPool(dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
return state.PushError("sqlite.exec: %v", err)
}
// Get connection with timeout
@ -224,8 +223,7 @@ func sqlExec(state *luajit.State) int {
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error()))
return -1
return state.PushError("sqlite.exec: connection timeout: %v", err)
}
defer pool.Put(conn)
@ -235,8 +233,7 @@ func sqlExec(state *luajit.State) int {
// Fast path for multi-statement scripts
if strings.Contains(query, ";") && !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
return state.PushError("sqlite.exec: %v", err)
}
state.PushNumber(float64(conn.Changes()))
return 1
@ -245,8 +242,7 @@ func sqlExec(state *luajit.State) int {
// Fast path for simple queries with no parameters
if !hasParams {
if err := sqlitex.Execute(conn, query, nil); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
return state.PushError("sqlite.exec: %v", err)
}
state.PushNumber(float64(conn.Changes()))
return 1
@ -255,14 +251,12 @@ func sqlExec(state *luajit.State) int {
// Create execution options for parameterized query
var execOpts sqlitex.ExecOptions
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
return state.PushError("sqlite.exec: %v", err)
}
// Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
return state.PushError("sqlite.exec: %v", err)
}
// Return affected rows
@ -273,11 +267,17 @@ func sqlExec(state *luajit.State) int {
// setupParams configures execution options with parameters from Lua
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
if state.IsTable(paramIndex) {
params, err := state.ToTable(paramIndex)
paramsAny, err := state.SafeToTable(paramIndex)
if err != nil {
return fmt.Errorf("invalid parameters: %w", err)
}
// Type assert to map[string]any
params, ok := paramsAny.(map[string]any)
if !ok {
return fmt.Errorf("parameters must be a table")
}
// Check for array-style params
if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
@ -321,20 +321,24 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
// sqlGetOne executes a query and returns only the first row
func sqlGetOne(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.get_one: requires database name and query")
return -1
if err := state.CheckMinArgs(2); err != nil {
return state.PushError("sqlite.get_one: %v", err)
}
dbName := state.ToString(1)
query := state.ToString(2)
dbName, err := state.SafeToString(1)
if err != nil {
return state.PushError("sqlite.get_one: database name must be string")
}
query, err := state.SafeToString(2)
if err != nil {
return state.PushError("sqlite.get_one: query must be string")
}
// Get pool
pool, err := getPool(dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
return state.PushError("sqlite.get_one: %v", err)
}
// Get connection with timeout
@ -343,8 +347,7 @@ func sqlGetOne(state *luajit.State) int {
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error()))
return -1
return state.PushError("sqlite.get_one: connection timeout: %v", err)
}
defer pool.Put(conn)
@ -355,8 +358,7 @@ func sqlGetOne(state *luajit.State) int {
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
return state.PushError("sqlite.get_one: %v", err)
}
}
@ -395,17 +397,15 @@ func sqlGetOne(state *luajit.State) int {
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
return state.PushError("sqlite.get_one: %v", err)
}
// Return result or nil if no rows
if result == nil {
state.PushNil()
} else {
if err := state.PushTable(result); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
if err := state.PushValue(result); err != nil {
return state.PushError("sqlite.get_one: %v", err)
}
}

View File

@ -35,12 +35,17 @@ func RegisterUtilFunctions(state *luajit.State) error {
// htmlSpecialChars converts special characters to HTML entities
func htmlSpecialChars(state *luajit.State) int {
if !state.IsString(1) {
if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
input, err := state.SafeToString(1)
if err != nil {
state.PushNil()
return 1
}
input := state.ToString(1)
result := html.EscapeString(input)
state.PushString(result)
return 1
@ -48,12 +53,17 @@ func htmlSpecialChars(state *luajit.State) int {
// htmlEntities is a more comprehensive version of htmlSpecialChars
func htmlEntities(state *luajit.State) int {
if !state.IsString(1) {
if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
input, err := state.SafeToString(1)
if err != nil {
state.PushNil()
return 1
}
input := state.ToString(1)
// First use HTML escape for standard entities
result := html.EscapeString(input)
@ -86,12 +96,17 @@ func htmlEntities(state *luajit.State) int {
// base64Encode encodes a string to base64
func base64Encode(state *luajit.State) int {
if !state.IsString(1) {
if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
input, err := state.SafeToString(1)
if err != nil {
state.PushNil()
return 1
}
input := state.ToString(1)
result := base64.StdEncoding.EncodeToString([]byte(input))
state.PushString(result)
return 1
@ -99,12 +114,17 @@ func base64Encode(state *luajit.State) int {
// base64Decode decodes a base64 string
func base64Decode(state *luajit.State) int {
if !state.IsString(1) {
if err := state.CheckExactArgs(1); err != nil {
state.PushNil()
return 1
}
input, err := state.SafeToString(1)
if err != nil {
state.PushNil()
return 1
}
input := state.ToString(1)
result, err := base64.StdEncoding.DecodeString(input)
if err != nil {
state.PushNil()