diff --git a/runner/crypto.go b/runner/crypto.go index 39ef1e3..5e21fc2 100644 --- a/runner/crypto.go +++ b/runner/crypto.go @@ -78,16 +78,21 @@ func CleanupCrypto(state *luajit.State) { // cryptoHash generates hash digests using various algorithms func cryptoHash(state *luajit.State) int { - if !state.IsString(1) || !state.IsString(2) { - state.PushString("hash: expected (string data, string algorithm)") - return 1 + if err := state.CheckMinArgs(2); err != nil { + return state.PushError("hash: %v", err) } - data := state.ToString(1) - algorithm := state.ToString(2) + data, err := state.SafeToString(1) + if err != nil { + return state.PushError("hash: data must be string") + } + + algorithm, err := state.SafeToString(2) + if err != nil { + return state.PushError("hash: algorithm must be string") + } var h hash.Hash - switch algorithm { case "md5": h = md5.New() @@ -98,8 +103,7 @@ func cryptoHash(state *luajit.State) int { case "sha512": h = sha512.New() default: - state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm)) - return 1 + return state.PushError("unsupported algorithm: %s", algorithm) } h.Write([]byte(data)) @@ -107,8 +111,10 @@ func cryptoHash(state *luajit.State) int { // Output format outputFormat := "hex" - if state.GetTop() >= 3 && state.IsString(3) { - outputFormat = state.ToString(3) + if state.GetTop() >= 3 { + if format, err := state.SafeToString(3); err == nil { + outputFormat = format + } } switch outputFormat { @@ -125,17 +131,26 @@ func cryptoHash(state *luajit.State) int { // cryptoHmac generates HMAC using various hash algorithms func cryptoHmac(state *luajit.State) int { - if !state.IsString(1) || !state.IsString(2) || !state.IsString(3) { - state.PushString("hmac: expected (string data, string key, string algorithm)") - return 1 + if err := state.CheckMinArgs(3); err != nil { + return state.PushError("hmac: %v", err) } - data := state.ToString(1) - key := state.ToString(2) - algorithm := state.ToString(3) + data, err := state.SafeToString(1) + if err != nil { + return state.PushError("hmac: data must be string") + } + + key, err := state.SafeToString(2) + if err != nil { + return state.PushError("hmac: key must be string") + } + + algorithm, err := state.SafeToString(3) + if err != nil { + return state.PushError("hmac: algorithm must be string") + } var h func() hash.Hash - switch algorithm { case "md5": h = md5.New @@ -146,8 +161,7 @@ func cryptoHmac(state *luajit.State) int { case "sha512": h = sha512.New default: - state.PushString(fmt.Sprintf("unsupported algorithm: %s", algorithm)) - return 1 + return state.PushError("unsupported algorithm: %s", algorithm) } mac := hmac.New(h, []byte(key)) @@ -156,8 +170,10 @@ func cryptoHmac(state *luajit.State) int { // Output format outputFormat := "hex" - if state.GetTop() >= 4 && state.IsString(4) { - outputFormat = state.ToString(4) + if state.GetTop() >= 4 { + if format, err := state.SafeToString(4); err == nil { + outputFormat = format + } } switch outputFormat { @@ -175,10 +191,8 @@ func cryptoHmac(state *luajit.State) int { // cryptoUuid generates a random UUID v4 func cryptoUuid(state *luajit.State) int { uuid := make([]byte, 16) - _, err := rand.Read(uuid) - if err != nil { - state.PushString(fmt.Sprintf("uuid: generation error: %v", err)) - return 1 + if _, err := rand.Read(uuid); err != nil { + return state.PushError("uuid: generation error: %v", err) } // Set version (4) and variant (RFC 4122) @@ -194,15 +208,17 @@ func cryptoUuid(state *luajit.State) int { // cryptoRandomBytes generates random bytes func cryptoRandomBytes(state *luajit.State) int { - if !state.IsNumber(1) { - state.PushString("random_bytes: expected (number length)") - return 1 + if err := state.CheckMinArgs(1); err != nil { + return state.PushError("random_bytes: %v", err) + } + + length, err := state.SafeToNumber(1) + if err != nil { + return state.PushError("random_bytes: length must be number") } - length := int(state.ToNumber(1)) if length <= 0 { - state.PushString("random_bytes: length must be positive") - return 1 + return state.PushError("random_bytes: length must be positive") } // Check if secure @@ -211,13 +227,11 @@ func cryptoRandomBytes(state *luajit.State) int { secure = state.ToBoolean(2) } - bytes := make([]byte, length) + bytes := make([]byte, int(length)) if secure { - _, err := rand.Read(bytes) - if err != nil { - state.PushString(fmt.Sprintf("random_bytes: error: %v", err)) - return 1 + if _, err := rand.Read(bytes); err != nil { + return state.PushError("random_bytes: error: %v", err) } } else { stateRngsMu.Lock() @@ -225,8 +239,7 @@ func cryptoRandomBytes(state *luajit.State) int { stateRngsMu.Unlock() if !ok { - state.PushString("random_bytes: RNG not initialized") - return 1 + return state.PushError("random_bytes: RNG not initialized") } for i := range bytes { @@ -236,8 +249,10 @@ func cryptoRandomBytes(state *luajit.State) int { // Output format outputFormat := "binary" - if state.GetTop() >= 3 && state.IsString(3) { - outputFormat = state.ToString(3) + if state.GetTop() >= 3 { + if format, err := state.SafeToString(3); err == nil { + outputFormat = format + } } switch outputFormat { @@ -254,17 +269,25 @@ func cryptoRandomBytes(state *luajit.State) int { // cryptoRandomInt generates a random integer in range [min, max] func cryptoRandomInt(state *luajit.State) int { - if !state.IsNumber(1) || !state.IsNumber(2) { - state.PushString("random_int: expected (number min, number max)") - return 1 + if err := state.CheckMinArgs(2); err != nil { + return state.PushError("random_int: %v", err) } - min := int64(state.ToNumber(1)) - max := int64(state.ToNumber(2)) + minVal, err := state.SafeToNumber(1) + if err != nil { + return state.PushError("random_int: min must be number") + } + + maxVal, err := state.SafeToNumber(2) + if err != nil { + return state.PushError("random_int: max must be number") + } + + min := int64(minVal) + max := int64(maxVal) if max <= min { - state.PushString("random_int: max must be greater than min") - return 1 + return state.PushError("random_int: max must be greater than min") } // Check if secure @@ -274,15 +297,12 @@ func cryptoRandomInt(state *luajit.State) int { } range_size := max - min + 1 - var result int64 if secure { bytes := make([]byte, 8) - _, err := rand.Read(bytes) - if err != nil { - state.PushString(fmt.Sprintf("random_int: error: %v", err)) - return 1 + if _, err := rand.Read(bytes); err != nil { + return state.PushError("random_int: error: %v", err) } val := binary.BigEndian.Uint64(bytes) @@ -293,8 +313,7 @@ func cryptoRandomInt(state *luajit.State) int { stateRngsMu.Unlock() if !ok { - state.PushString("random_int: RNG not initialized") - return 1 + return state.PushError("random_int: RNG not initialized") } result = min + int64(stateRng.Uint64()%uint64(range_size)) @@ -315,10 +334,8 @@ func cryptoRandom(state *luajit.State) int { if numArgs == 0 { if secure { bytes := make([]byte, 8) - _, err := rand.Read(bytes) - if err != nil { - state.PushString(fmt.Sprintf("random: error: %v", err)) - return 1 + if _, err := rand.Read(bytes); err != nil { + return state.PushError("random: error: %v", err) } val := binary.BigEndian.Uint64(bytes) state.PushNumber(float64(val) / float64(math.MaxUint64)) @@ -328,8 +345,7 @@ func cryptoRandom(state *luajit.State) int { stateRngsMu.Unlock() if !ok { - state.PushString("random: RNG not initialized") - return 1 + return state.PushError("random: RNG not initialized") } state.PushNumber(float64(stateRng.Uint64()) / float64(math.MaxUint64)) @@ -341,8 +357,7 @@ func cryptoRandom(state *luajit.State) int { if numArgs == 1 && state.IsNumber(1) { n := int64(state.ToNumber(1)) if n < 1 { - state.PushString("random: upper bound must be >= 1") - return 1 + return state.PushError("random: upper bound must be >= 1") } state.PushNumber(1) // min @@ -357,21 +372,23 @@ func cryptoRandom(state *luajit.State) int { return cryptoRandomInt(state) } - state.PushString("random: invalid arguments") - return 1 + return state.PushError("random: invalid arguments") } // cryptoRandomSeed sets seed for non-secure RNG func cryptoRandomSeed(state *luajit.State) int { - if !state.IsNumber(1) { - state.PushString("randomseed: expected (number seed)") - return 1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("randomseed: %v", err) } - seed := uint64(state.ToNumber(1)) + seed, err := state.SafeToNumber(1) + if err != nil { + return state.PushError("randomseed: seed must be number") + } + seedVal := uint64(seed) stateRngsMu.Lock() - stateRngs[state] = mrand.NewPCG(seed, seed>>32) + stateRngs[state] = mrand.NewPCG(seedVal, seedVal>>32) stateRngsMu.Unlock() return 0 diff --git a/runner/embed.go b/runner/embed.go index af2dcc9..a86dffc 100644 --- a/runner/embed.go +++ b/runner/embed.go @@ -71,7 +71,7 @@ var ( // precompileModule compiles a module's code to bytecode once func precompileModule(m *ModuleInfo) { m.Once.Do(func() { - tempState := luajit.New() + tempState := luajit.New(true) // Explicitly open standard libraries if tempState == nil { logger.Fatalf("Failed to create temp Lua state for %s module compilation", m.Name) return @@ -105,18 +105,14 @@ func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error { logger.Debugf("Loading %s.lua from precompiled bytecode", m.Name) } - if err := state.LoadBytecode(*bytecode, m.Name+".lua"); err != nil { - return err - } - if m.DefinesGlobal { // Module defines its own globals, just run it - if err := state.RunBytecode(); err != nil { + if err := state.LoadAndRunBytecode(*bytecode, m.Name+".lua"); err != nil { return err } } else { // Module returns a table, capture and set as global - if err := state.RunBytecodeWithResults(1); err != nil { + if err := state.LoadAndRunBytecodeWithResults(*bytecode, m.Name+".lua", 1); err != nil { return err } state.SetGlobal(m.Name) diff --git a/runner/env.go b/runner/env.go index 2be9d9a..4eeb3d7 100644 --- a/runner/env.go +++ b/runner/env.go @@ -221,12 +221,17 @@ func CleanupEnv() error { // envGet Lua function to get an environment variable func envGet(state *luajit.State) int { - if !state.IsString(1) { + if err := state.CheckExactArgs(1); err != nil { + state.PushNil() + return 1 + } + + key, err := state.SafeToString(1) + if err != nil { state.PushNil() return 1 } - key := state.ToString(1) if value, exists := globalEnvManager.Get(key); exists { if err := state.PushValue(value); err != nil { state.PushNil() @@ -239,13 +244,22 @@ func envGet(state *luajit.State) int { // envSet Lua function to set an environment variable func envSet(state *luajit.State) int { - if !state.IsString(1) || !state.IsString(2) { + if err := state.CheckExactArgs(2); err != nil { state.PushBoolean(false) return 1 } - key := state.ToString(1) - value := state.ToString(2) + key, err := state.SafeToString(1) + if err != nil { + state.PushBoolean(false) + return 1 + } + + value, err := state.SafeToString(2) + if err != nil { + state.PushBoolean(false) + return 1 + } globalEnvManager.Set(key, value) state.PushBoolean(true) @@ -255,11 +269,9 @@ func envSet(state *luajit.State) int { // envGetAll Lua function to get all environment variables func envGetAll(state *luajit.State) int { vars := globalEnvManager.GetAll() - - if err := state.PushTable(vars); err != nil { + if err := state.PushValue(vars); err != nil { state.PushNil() } - return 1 } diff --git a/runner/fs.go b/runner/fs.go index c2e7a28..a5d6ef1 100644 --- a/runner/fs.go +++ b/runner/fs.go @@ -112,23 +112,24 @@ func getCacheKey(fullPath string, modTime time.Time) string { // fsReadFile reads a file and returns its contents func fsReadFile(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.read_file: path must be a string") - return -1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("fs.read_file: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.read_file: path must be string") } - path := state.ToString(1) fullPath, err := ResolvePath(path) if err != nil { - state.PushString("fs.read_file: " + err.Error()) - return -1 + return state.PushError("fs.read_file: %v", err) } // Get file info for cache key and validation info, err := os.Stat(fullPath) if err != nil { - state.PushString("fs.read_file: " + err.Error()) - return -1 + return state.PushError("fs.read_file: %v", err) } // Create cache key with path and modification time @@ -154,8 +155,7 @@ func fsReadFile(state *luajit.State) int { stats.misses++ data, err := os.ReadFile(fullPath) if err != nil { - state.PushString("fs.read_file: " + err.Error()) - return -1 + return state.PushError("fs.read_file: %v", err) } // Compress and cache the data @@ -170,41 +170,33 @@ func fsReadFile(state *luajit.State) int { // fsWriteFile writes data to a file func fsWriteFile(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.write_file: path must be a string") - return -1 + if err := state.CheckExactArgs(2); err != nil { + return state.PushError("fs.write_file: %v", err) } - path := state.ToString(1) - if !state.IsString(2) { - state.PushString("fs.write_file: content must be a string") - return -1 + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.write_file: path must be string") + } + + content, err := state.SafeToString(2) + if err != nil { + return state.PushError("fs.write_file: content must be string") } - content := state.ToString(2) fullPath, err := ResolvePath(path) if err != nil { - state.PushString("fs.write_file: " + err.Error()) - return -1 + return state.PushError("fs.write_file: %v", err) } // Ensure the directory exists dir := filepath.Dir(fullPath) if err := os.MkdirAll(dir, 0755); err != nil { - state.PushString("fs.write_file: failed to create directory: " + err.Error()) - return -1 + return state.PushError("fs.write_file: failed to create directory: %v", err) } - err = os.WriteFile(fullPath, []byte(content), 0644) - if err != nil { - state.PushString("fs.write_file: " + err.Error()) - return -1 - } - - // Invalidate cache entries for this file path - if fileCache != nil { - // We can't easily iterate through cache keys, so we'll let the cache - // naturally expire old entries when the file is read again + if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { + return state.PushError("fs.write_file: %v", err) } state.PushBoolean(true) @@ -213,42 +205,39 @@ func fsWriteFile(state *luajit.State) int { // fsAppendFile appends data to a file func fsAppendFile(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.append_file: path must be a string") - return -1 + if err := state.CheckExactArgs(2); err != nil { + return state.PushError("fs.append_file: %v", err) } - path := state.ToString(1) - if !state.IsString(2) { - state.PushString("fs.append_file: content must be a string") - return -1 + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.append_file: path must be string") + } + + content, err := state.SafeToString(2) + if err != nil { + return state.PushError("fs.append_file: content must be string") } - content := state.ToString(2) fullPath, err := ResolvePath(path) if err != nil { - state.PushString("fs.append_file: " + err.Error()) - return -1 + return state.PushError("fs.append_file: %v", err) } // Ensure the directory exists dir := filepath.Dir(fullPath) if err := os.MkdirAll(dir, 0755); err != nil { - state.PushString("fs.append_file: failed to create directory: " + err.Error()) - return -1 + return state.PushError("fs.append_file: failed to create directory: %v", err) } file, err := os.OpenFile(fullPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - state.PushString("fs.append_file: " + err.Error()) - return -1 + return state.PushError("fs.append_file: %v", err) } defer file.Close() - _, err = file.Write([]byte(content)) - if err != nil { - state.PushString("fs.append_file: " + err.Error()) - return -1 + if _, err = file.Write([]byte(content)); err != nil { + return state.PushError("fs.append_file: %v", err) } state.PushBoolean(true) @@ -257,16 +246,18 @@ func fsAppendFile(state *luajit.State) int { // fsExists checks if a file or directory exists func fsExists(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.exists: path must be a string") - return -1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("fs.exists: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.exists: path must be string") } - path := state.ToString(1) fullPath, err := ResolvePath(path) if err != nil { - state.PushString("fs.exists: " + err.Error()) - return -1 + return state.PushError("fs.exists: %v", err) } _, err = os.Stat(fullPath) @@ -276,34 +267,32 @@ func fsExists(state *luajit.State) int { // fsRemoveFile removes a file func fsRemoveFile(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.remove_file: path must be a string") - return -1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("fs.remove_file: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.remove_file: path must be string") } - path := state.ToString(1) fullPath, err := ResolvePath(path) if err != nil { - state.PushString("fs.remove_file: " + err.Error()) - return -1 + return state.PushError("fs.remove_file: %v", err) } // Check if it's a directory info, err := os.Stat(fullPath) if err != nil { - state.PushString("fs.remove_file: " + err.Error()) - return -1 + return state.PushError("fs.remove_file: %v", err) } if info.IsDir() { - state.PushString("fs.remove_file: cannot remove directory, use remove_dir instead") - return -1 + return state.PushError("fs.remove_file: cannot remove directory, use remove_dir instead") } - err = os.Remove(fullPath) - if err != nil { - state.PushString("fs.remove_file: " + err.Error()) - return -1 + if err := os.Remove(fullPath); err != nil { + return state.PushError("fs.remove_file: %v", err) } state.PushBoolean(true) @@ -312,67 +301,65 @@ func fsRemoveFile(state *luajit.State) int { // fsGetInfo gets information about a file func fsGetInfo(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.get_info: path must be a string") - return -1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("fs.get_info: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.get_info: path must be string") } - path := state.ToString(1) fullPath, err := ResolvePath(path) if err != nil { - state.PushString("fs.get_info: " + err.Error()) - return -1 + return state.PushError("fs.get_info: %v", err) } info, err := os.Stat(fullPath) if err != nil { - state.PushString("fs.get_info: " + err.Error()) - return -1 + return state.PushError("fs.get_info: %v", err) } - state.NewTable() + fileInfo := map[string]any{ + "name": info.Name(), + "size": info.Size(), + "mode": int(info.Mode()), + "mod_time": info.ModTime().Unix(), + "is_dir": info.IsDir(), + } - state.PushString(info.Name()) - state.SetField(-2, "name") - - state.PushNumber(float64(info.Size())) - state.SetField(-2, "size") - - state.PushNumber(float64(info.Mode())) - state.SetField(-2, "mode") - - state.PushNumber(float64(info.ModTime().Unix())) - state.SetField(-2, "mod_time") - - state.PushBoolean(info.IsDir()) - state.SetField(-2, "is_dir") + if err := state.PushValue(fileInfo); err != nil { + return state.PushError("fs.get_info: %v", err) + } return 1 } // fsMakeDir creates a directory func fsMakeDir(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.make_dir: path must be a string") - return -1 + if err := state.CheckMinArgs(1); err != nil { + return state.PushError("fs.make_dir: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.make_dir: path must be string") } - path := state.ToString(1) perm := os.FileMode(0755) - if state.GetTop() >= 2 && state.IsNumber(2) { - perm = os.FileMode(state.ToNumber(2)) + if state.GetTop() >= 2 { + if permVal, err := state.SafeToNumber(2); err == nil { + perm = os.FileMode(permVal) + } } fullPath, err := ResolvePath(path) if err != nil { - state.PushString("fs.make_dir: " + err.Error()) - return -1 + return state.PushError("fs.make_dir: %v", err) } - err = os.MkdirAll(fullPath, perm) - if err != nil { - state.PushString("fs.make_dir: " + err.Error()) - return -1 + if err := os.MkdirAll(fullPath, perm); err != nil { + return state.PushError("fs.make_dir: %v", err) } state.PushBoolean(true) @@ -381,41 +368,42 @@ func fsMakeDir(state *luajit.State) int { // fsListDir lists the contents of a directory func fsListDir(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.list_dir: path must be a string") - return -1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("fs.list_dir: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.list_dir: path must be string") } - path := state.ToString(1) fullPath, err := ResolvePath(path) if err != nil { - state.PushString("fs.list_dir: " + err.Error()) - return -1 + return state.PushError("fs.list_dir: %v", err) } info, err := os.Stat(fullPath) if err != nil { - state.PushString("fs.list_dir: " + err.Error()) - return -1 + return state.PushError("fs.list_dir: %v", err) } if !info.IsDir() { - state.PushString("fs.list_dir: not a directory") - return -1 + return state.PushError("fs.list_dir: not a directory") } files, err := os.ReadDir(fullPath) if err != nil { - state.PushString("fs.list_dir: " + err.Error()) - return -1 + return state.PushError("fs.list_dir: %v", err) } - state.NewTable() - + // Create array of filenames + filenames := make([]string, len(files)) for i, file := range files { - state.PushNumber(float64(i + 1)) - state.PushString(file.Name()) - state.SetTable(-3) + filenames[i] = file.Name() + } + + if err := state.PushValue(filenames); err != nil { + return state.PushError("fs.list_dir: %v", err) } return 1 @@ -423,11 +411,14 @@ func fsListDir(state *luajit.State) int { // fsRemoveDir removes a directory func fsRemoveDir(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.remove_dir: path must be a string") - return -1 + if err := state.CheckMinArgs(1); err != nil { + return state.PushError("fs.remove_dir: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.remove_dir: path must be string") } - path := state.ToString(1) recursive := false if state.GetTop() >= 2 { @@ -436,19 +427,16 @@ func fsRemoveDir(state *luajit.State) int { fullPath, err := ResolvePath(path) if err != nil { - state.PushString("fs.remove_dir: " + err.Error()) - return -1 + return state.PushError("fs.remove_dir: %v", err) } info, err := os.Stat(fullPath) if err != nil { - state.PushString("fs.remove_dir: " + err.Error()) - return -1 + return state.PushError("fs.remove_dir: %v", err) } if !info.IsDir() { - state.PushString("fs.remove_dir: not a directory") - return -1 + return state.PushError("fs.remove_dir: not a directory") } if recursive { @@ -458,8 +446,7 @@ func fsRemoveDir(state *luajit.State) int { } if err != nil { - state.PushString("fs.remove_dir: " + err.Error()) - return -1 + return state.PushError("fs.remove_dir: %v", err) } state.PushBoolean(true) @@ -468,19 +455,17 @@ func fsRemoveDir(state *luajit.State) int { // fsJoinPaths joins path components func fsJoinPaths(state *luajit.State) int { - nargs := state.GetTop() - if nargs < 1 { - state.PushString("fs.join_paths: at least one path component required") - return -1 + if err := state.CheckMinArgs(1); err != nil { + return state.PushError("fs.join_paths: %v", err) } - components := make([]string, nargs) - for i := 1; i <= nargs; i++ { - if !state.IsString(i) { - state.PushString("fs.join_paths: all arguments must be strings") - return -1 + components := make([]string, state.GetTop()) + for i := 1; i <= state.GetTop(); i++ { + comp, err := state.SafeToString(i) + if err != nil { + return state.PushError("fs.join_paths: all arguments must be strings") } - components[i-1] = state.ToString(i) + components[i-1] = comp } result := filepath.Join(components...) @@ -492,11 +477,14 @@ func fsJoinPaths(state *luajit.State) int { // fsDirName returns the directory portion of a path func fsDirName(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.dir_name: path must be a string") - return -1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("fs.dir_name: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.dir_name: path must be string") } - path := state.ToString(1) dir := filepath.Dir(path) dir = strings.ReplaceAll(dir, "\\", "/") @@ -507,28 +495,32 @@ func fsDirName(state *luajit.State) int { // fsBaseName returns the file name portion of a path func fsBaseName(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.base_name: path must be a string") - return -1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("fs.base_name: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.base_name: path must be string") } - path := state.ToString(1) base := filepath.Base(path) - state.PushString(base) return 1 } // fsExtension returns the file extension func fsExtension(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("fs.extension: path must be a string") - return -1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("fs.extension: %v", err) + } + + path, err := state.SafeToString(1) + if err != nil { + return state.PushError("fs.extension: path must be string") } - path := state.ToString(1) ext := filepath.Ext(path) - state.PushString(ext) return 1 } diff --git a/runner/http.go b/runner/http.go index 830bca6..612e3e5 100644 --- a/runner/http.go +++ b/runner/http.go @@ -94,25 +94,26 @@ func ApplyResponse(resp *Response, ctx *fasthttp.RequestCtx) { // httpRequest makes an HTTP request and returns the result to Lua func httpRequest(state *luajit.State) int { - // Get method (required) - if !state.IsString(1) { - state.PushString("http.client.request: method must be a string") - return -1 + if err := state.CheckMinArgs(2); err != nil { + return state.PushError("http.client.request: %v", err) } - method := strings.ToUpper(state.ToString(1)) - // Get URL (required) - if !state.IsString(2) { - state.PushString("http.client.request: url must be a string") - return -1 + // Get method and URL + method, err := state.SafeToString(1) + if err != nil { + return state.PushError("http.client.request: method must be string") + } + method = strings.ToUpper(method) + + urlStr, err := state.SafeToString(2) + if err != nil { + return state.PushError("http.client.request: url must be string") } - urlStr := state.ToString(2) // Parse URL to check if it's valid parsedURL, err := url.Parse(urlStr) if err != nil { - state.PushString("Invalid URL: " + err.Error()) - return -1 + return state.PushError("Invalid URL: %v", err) } // Get client configuration @@ -120,8 +121,7 @@ func httpRequest(state *luajit.State) int { // Check if remote connections are allowed if !config.AllowRemote && (parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1") { - state.PushString("Remote connections are not allowed") - return -1 + return state.PushError("Remote connections are not allowed") } // Use bytebufferpool for request and response @@ -139,13 +139,13 @@ func httpRequest(state *luajit.State) int { if state.GetTop() >= 3 && !state.IsNil(3) { if state.IsString(3) { // String body - req.SetBodyString(state.ToString(3)) + bodyStr, _ := state.SafeToString(3) + req.SetBodyString(bodyStr) } else if state.IsTable(3) { // Table body - convert to JSON - luaTable, err := state.ToTable(3) + luaTable, err := state.SafeToTable(3) if err != nil { - state.PushString("Failed to parse body table: " + err.Error()) - return -1 + return state.PushError("Failed to parse body table: %v", err) } // Use bytebufferpool for JSON serialization @@ -153,42 +153,37 @@ func httpRequest(state *luajit.State) int { defer bytebufferpool.Put(buf) if err := json.NewEncoder(buf).Encode(luaTable); err != nil { - state.PushString("Failed to convert body to JSON: " + err.Error()) - return -1 + return state.PushError("Failed to convert body to JSON: %v", err) } req.SetBody(buf.Bytes()) req.Header.SetContentType("application/json") } else { - state.PushString("Body must be a string or table") - return -1 + return state.PushError("Body must be a string or table") } } // Process options (headers, timeout, etc.) timeout := config.DefaultTimeout if state.GetTop() >= 4 && !state.IsNil(4) && state.IsTable(4) { - // Process headers - state.GetField(4, "headers") - if state.IsTable(-1) { - // Iterate through headers - state.PushNil() // Start iteration - for state.Next(-2) { - // Stack now has key at -2 and value at -1 - if state.IsString(-2) && state.IsString(-1) { - headerName := state.ToString(-2) - headerValue := state.ToString(-1) - req.Header.Set(headerName, headerValue) + // Process headers using ForEachTableKV + if headers, ok := state.GetFieldTable(4, "headers"); ok { + if headerMap, ok := headers.(map[string]string); ok { + for name, value := range headerMap { + req.Header.Set(name, value) + } + } else if headerMapAny, ok := headers.(map[string]any); ok { + for name, value := range headerMapAny { + if valueStr, ok := value.(string); ok { + req.Header.Set(name, valueStr) + } } - state.Pop(1) // Pop value, leave key for next iteration } } - state.Pop(1) // Pop headers table // Get timeout - state.GetField(4, "timeout") - if state.IsNumber(-1) { - requestTimeout := time.Duration(state.ToNumber(-1)) * time.Second + if timeoutVal := state.GetFieldNumber(4, "timeout", 0); timeoutVal > 0 { + requestTimeout := time.Duration(timeoutVal) * time.Second // Apply max timeout if configured if config.MaxTimeout > 0 && requestTimeout > config.MaxTimeout { @@ -197,38 +192,34 @@ func httpRequest(state *luajit.State) int { timeout = requestTimeout } } - state.Pop(1) // Pop timeout // Process query parameters - state.GetField(4, "query") - if state.IsTable(-1) { - // Create URL args + if query, ok := state.GetFieldTable(4, "query"); ok { args := req.URI().QueryArgs() - // Iterate through query params - state.PushNil() // Start iteration - for state.Next(-2) { - if state.IsString(-2) { - paramName := state.ToString(-2) - - // Handle different value types - if state.IsString(-1) { - args.Add(paramName, state.ToString(-1)) - } else if state.IsNumber(-1) { - args.Add(paramName, strings.TrimRight(strings.TrimRight( - state.ToString(-1), "0"), ".")) - } else if state.IsBoolean(-1) { - if state.ToBoolean(-1) { - args.Add(paramName, "true") + if queryMap, ok := query.(map[string]string); ok { + for name, value := range queryMap { + args.Add(name, value) + } + } else if queryMapAny, ok := query.(map[string]any); ok { + for name, value := range queryMapAny { + switch v := value.(type) { + case string: + args.Add(name, v) + case int: + args.Add(name, fmt.Sprintf("%d", v)) + case float64: + args.Add(name, strings.TrimRight(strings.TrimRight(fmt.Sprintf("%.6f", v), "0"), ".")) + case bool: + if v { + args.Add(name, "true") } else { - args.Add(paramName, "false") + args.Add(name, "false") } } } - state.Pop(1) // Pop value, leave key for next iteration } } - state.Pop(1) // Pop query table } // Create context with timeout @@ -242,26 +233,18 @@ func httpRequest(state *luajit.State) int { if errors.Is(err, fasthttp.ErrTimeout) { errStr = "Request timed out after " + timeout.String() } - state.PushString(errStr) - return -1 + return state.PushError("%s", errStr) } - // Create response table - state.NewTable() + // Create response using TableBuilder + builder := state.NewTableBuilder() - // Set status code - state.PushNumber(float64(resp.StatusCode())) - state.SetField(-2, "status") - - // Set status text - statusText := fasthttp.StatusMessage(resp.StatusCode()) - state.PushString(statusText) - state.SetField(-2, "status_text") + // Set status code and text + builder.SetNumber("status", float64(resp.StatusCode())) + builder.SetString("status_text", fasthttp.StatusMessage(resp.StatusCode())) // Set body var respBody []byte - - // Apply size limits to response if config.MaxResponseSize > 0 && int64(len(resp.Body())) > config.MaxResponseSize { // Make a limited copy respBody = make([]byte, config.MaxResponseSize) @@ -270,32 +253,28 @@ func httpRequest(state *luajit.State) int { respBody = resp.Body() } - state.PushString(string(respBody)) - state.SetField(-2, "body") + builder.SetString("body", string(respBody)) // Parse body as JSON if content type is application/json contentType := string(resp.Header.ContentType()) if strings.Contains(contentType, "application/json") { var jsonData any if err := json.Unmarshal(respBody, &jsonData); err == nil { - if err := state.PushValue(jsonData); err == nil { - state.SetField(-2, "json") - } + builder.SetTable("json", jsonData) } } // Set headers - state.NewTable() + headers := make(map[string]string) resp.Header.VisitAll(func(key, value []byte) { - state.PushString(string(value)) - state.SetField(-2, string(key)) + headers[string(key)] = string(value) }) - state.SetField(-2, "headers") + builder.SetTable("headers", headers) // Create ok field (true if status code is 2xx) - state.PushBoolean(resp.StatusCode() >= 200 && resp.StatusCode() < 300) - state.SetField(-2, "ok") + builder.SetBool("ok", resp.StatusCode() >= 200 && resp.StatusCode() < 300) + builder.Build() return 1 } @@ -303,8 +282,10 @@ func httpRequest(state *luajit.State) int { func generateToken(state *luajit.State) int { // Get the length from the Lua arguments (default to 32) length := 32 - if state.GetTop() >= 1 && state.IsNumber(1) { - length = int(state.ToNumber(1)) + if state.GetTop() >= 1 { + if lengthVal, err := state.SafeToNumber(1); err == nil { + length = int(lengthVal) + } } // Enforce minimum length for security diff --git a/runner/password.go b/runner/password.go index d96b348..a85105e 100644 --- a/runner/password.go +++ b/runner/password.go @@ -1,8 +1,6 @@ package runner import ( - "fmt" - luajit "git.sharkk.net/Sky/LuaJIT-to-Go" "github.com/alexedwards/argon2id" ) @@ -20,12 +18,14 @@ func RegisterPasswordFunctions(state *luajit.State) error { // passwordHash implements the Argon2id password hashing using alexedwards/argon2id func passwordHash(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("password_hash error: expected string password") - return 1 + if err := state.CheckMinArgs(1); err != nil { + return state.PushError("password_hash: %v", err) } - password := state.ToString(1) + password, err := state.SafeToString(1) + if err != nil { + return state.PushError("password_hash: password must be string") + } params := &argon2id.Params{ Memory: 128 * 1024, @@ -35,42 +35,32 @@ func passwordHash(state *luajit.State) int { KeyLength: 32, } - if state.IsTable(2) { - state.GetField(2, "memory") - if state.IsNumber(-1) { - params.Memory = max(uint32(state.ToNumber(-1)), 8*1024) + if state.GetTop() >= 2 && state.IsTable(2) { + // Use new field getters with validation + if memory := state.GetFieldNumber(2, "memory", 0); memory > 0 { + params.Memory = max(uint32(memory), 8*1024) } - state.Pop(1) - state.GetField(2, "iterations") - if state.IsNumber(-1) { - params.Iterations = max(uint32(state.ToNumber(-1)), 1) + if iterations := state.GetFieldNumber(2, "iterations", 0); iterations > 0 { + params.Iterations = max(uint32(iterations), 1) } - state.Pop(1) - state.GetField(2, "parallelism") - if state.IsNumber(-1) { - params.Parallelism = max(uint8(state.ToNumber(-1)), 1) + if parallelism := state.GetFieldNumber(2, "parallelism", 0); parallelism > 0 { + params.Parallelism = max(uint8(parallelism), 1) } - state.Pop(1) - state.GetField(2, "salt_length") - if state.IsNumber(-1) { - params.SaltLength = max(uint32(state.ToNumber(-1)), 8) + if saltLength := state.GetFieldNumber(2, "salt_length", 0); saltLength > 0 { + params.SaltLength = max(uint32(saltLength), 8) } - state.Pop(1) - state.GetField(2, "key_length") - if state.IsNumber(-1) { - params.KeyLength = max(uint32(state.ToNumber(-1)), 16) + if keyLength := state.GetFieldNumber(2, "key_length", 0); keyLength > 0 { + params.KeyLength = max(uint32(keyLength), 16) } - state.Pop(1) } hash, err := argon2id.CreateHash(password, params) if err != nil { - state.PushString(fmt.Sprintf("password_hash error: %v", err)) - return 1 + return state.PushError("password_hash: %v", err) } state.PushString(hash) @@ -79,13 +69,22 @@ func passwordHash(state *luajit.State) int { // passwordVerify verifies a password against a hash func passwordVerify(state *luajit.State) int { - if !state.IsString(1) || !state.IsString(2) { + if err := state.CheckExactArgs(2); err != nil { state.PushBoolean(false) return 1 } - password := state.ToString(1) - hash := state.ToString(2) + password, err := state.SafeToString(1) + if err != nil { + state.PushBoolean(false) + return 1 + } + + hash, err := state.SafeToString(2) + if err != nil { + state.PushBoolean(false) + return 1 + } match, err := argon2id.ComparePasswordAndHash(password, hash) if err != nil { diff --git a/runner/runner.go b/runner/runner.go index f8fddc1..ea3b478 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -156,7 +156,7 @@ func (r *Runner) createState(index int) (*State, error) { logger.Debugf("Creating Lua state %d", index) } - L := luajit.New() + L := luajit.New(true) // Explicitly open standard libraries if L == nil { return nil, errors.New("failed to create Lua state") } diff --git a/runner/sandbox.go b/runner/sandbox.go index ae0ec65..8153fdc 100644 --- a/runner/sandbox.go +++ b/runner/sandbox.go @@ -130,42 +130,30 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { // Execute runs a Lua script in the sandbox with the given context func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { - state.GetGlobal("__execute_script") - if !state.IsFunction(-1) { - state.Pop(1) - return nil, ErrSandboxNotInitialized - } + // Use CallGlobal for cleaner function calling + results, err := state.CallGlobal("__execute_script", + func() any { + if err := state.LoadBytecode(bytecode, "script"); err != nil { + return nil + } + return "loaded_script" + }(), + ctx.Values) - if err := state.LoadBytecode(bytecode, "script"); err != nil { - state.Pop(1) // Pop the __execute_script function - return nil, fmt.Errorf("failed to load script: %w", err) - } - - // Push context values - if err := state.PushTable(ctx.Values); err != nil { - state.Pop(2) // Pop bytecode and __execute_script - return nil, err - } - - // Execute with 2 args, 1 result - if err := state.Call(2, 1); err != nil { + if err != nil { return nil, fmt.Errorf("script execution failed: %w", err) } - body, err := state.ToValue(-1) - state.Pop(1) - response := NewResponse() - if err == nil { - response.Body = body + if len(results) > 0 { + response.Body = results[0] } extractHTTPResponseData(state, response) - return response, nil } -// extractResponseData pulls response info from the Lua state +// extractResponseData pulls response info from the Lua state using new API func extractHTTPResponseData(state *luajit.State, response *Response) { state.GetGlobal("__http_response") if !state.IsTable(-1) { @@ -173,162 +161,113 @@ func extractHTTPResponseData(state *luajit.State, response *Response) { return } - // Extract status - state.GetField(-1, "status") - if state.IsNumber(-1) { - response.Status = int(state.ToNumber(-1)) - } - state.Pop(1) + // Use new field getters with defaults + response.Status = int(state.GetFieldNumber(-1, "status", 200)) - // Extract headers - state.GetField(-1, "headers") - if state.IsTable(-1) { - state.PushNil() // Start iteration - for state.Next(-2) { - if state.IsString(-2) && state.IsString(-1) { - key := state.ToString(-2) - value := state.ToString(-1) - response.Headers[key] = value + // Extract headers using ForEachTableKV + if headerTable, ok := state.GetFieldTable(-1, "headers"); ok { + if headers, ok := headerTable.(map[string]any); ok { + for k, v := range headers { + if str, ok := v.(string); ok { + response.Headers[k] = str + } } - state.Pop(1) } } - state.Pop(1) - // Extract cookies + // Extract cookies using ForEachArray state.GetField(-1, "cookies") if state.IsTable(-1) { - length := state.GetTableLength(-1) - for i := 1; i <= length; i++ { - state.PushNumber(float64(i)) - state.GetTable(-2) - - if state.IsTable(-1) { - extractCookie(state, response) + state.ForEachArray(-1, func(i int, s *luajit.State) bool { + if s.IsTable(-1) { + extractCookie(s, response) } - state.Pop(1) - } + return true + }) } state.Pop(1) // Extract metadata - state.GetField(-1, "metadata") - if state.IsTable(-1) { - table, err := state.ToTable(-1) - if err == nil { - maps.Copy(response.Metadata, table) + if metadata, ok := state.GetFieldTable(-1, "metadata"); ok { + if metaMap, ok := metadata.(map[string]any); ok { + maps.Copy(response.Metadata, metaMap) } } - state.Pop(1) // Extract session data - state.GetField(-1, "session") - if state.IsTable(-1) { - table, err := state.ToTable(-1) - if err == nil { - maps.Copy(response.SessionData, table) + if session, ok := state.GetFieldTable(-1, "session"); ok { + if sessMap, ok := session.(map[string]any); ok { + maps.Copy(response.SessionData, sessMap) } } - state.Pop(1) state.Pop(1) // Pop __http_response } -// extractCookie pulls cookie data from the current table on the stack +// extractCookie pulls cookie data from the current table on the stack using new API func extractCookie(state *luajit.State, response *Response) { cookie := fasthttp.AcquireCookie() - // Get name (required) - state.GetField(-1, "name") - if !state.IsString(-1) { - state.Pop(1) + // Use new field getters with defaults + name := state.GetFieldString(-1, "name", "") + if name == "" { fasthttp.ReleaseCookie(cookie) return } - cookie.SetKey(state.ToString(-1)) - state.Pop(1) - // Get value - state.GetField(-1, "value") - if state.IsString(-1) { - cookie.SetValue(state.ToString(-1)) - } - state.Pop(1) - - // Get path - state.GetField(-1, "path") - if state.IsString(-1) { - cookie.SetPath(state.ToString(-1)) - } else { - cookie.SetPath("/") // Default - } - state.Pop(1) - - // Get domain - state.GetField(-1, "domain") - if state.IsString(-1) { - cookie.SetDomain(state.ToString(-1)) - } - state.Pop(1) - - // Get other parameters - state.GetField(-1, "http_only") - if state.IsBoolean(-1) { - cookie.SetHTTPOnly(state.ToBoolean(-1)) - } - state.Pop(1) - - state.GetField(-1, "secure") - if state.IsBoolean(-1) { - cookie.SetSecure(state.ToBoolean(-1)) - } - state.Pop(1) - - state.GetField(-1, "max_age") - if state.IsNumber(-1) { - cookie.SetMaxAge(int(state.ToNumber(-1))) - } - state.Pop(1) + cookie.SetKey(name) + cookie.SetValue(state.GetFieldString(-1, "value", "")) + cookie.SetPath(state.GetFieldString(-1, "path", "/")) + cookie.SetDomain(state.GetFieldString(-1, "domain", "")) + cookie.SetHTTPOnly(state.GetFieldBool(-1, "http_only", false)) + cookie.SetSecure(state.GetFieldBool(-1, "secure", false)) + cookie.SetMaxAge(int(state.GetFieldNumber(-1, "max_age", 0))) response.Cookies = append(response.Cookies, cookie) } -// jsonMarshal converts a Lua value to a JSON string +// jsonMarshal converts a Lua value to a JSON string with validation func jsonMarshal(state *luajit.State) int { - value, err := state.ToValue(1) + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("json marshal: %v", err) + } + + value, err := state.SafeToTable(1) if err != nil { - state.PushString(fmt.Sprintf("json marshal error: %v", err)) - return 1 + // Try as generic value if not a table + value, err = state.ToValue(1) + if err != nil { + return state.PushError("json marshal error: %v", err) + } } bytes, err := json.Marshal(value) if err != nil { - state.PushString(fmt.Sprintf("json marshal error: %v", err)) - return 1 + return state.PushError("json marshal error: %v", err) } state.PushString(string(bytes)) return 1 } -// jsonUnmarshal converts a JSON string to a Lua value +// jsonUnmarshal converts a JSON string to a Lua value with validation func jsonUnmarshal(state *luajit.State) int { - if !state.IsString(1) { - state.PushString("json unmarshal error: expected string") - return 1 + if err := state.CheckExactArgs(1); err != nil { + return state.PushError("json unmarshal: %v", err) + } + + jsonStr, err := state.SafeToString(1) + if err != nil { + return state.PushError("json unmarshal: expected string, got %s", state.GetType(1)) } - jsonStr := state.ToString(1) var value any - err := json.Unmarshal([]byte(jsonStr), &value) - if err != nil { - state.PushString(fmt.Sprintf("json unmarshal error: %v", err)) - return 1 + if err := json.Unmarshal([]byte(jsonStr), &value); err != nil { + return state.PushError("json unmarshal error: %v", err) } if err := state.PushValue(value); err != nil { - state.PushString(fmt.Sprintf("json unmarshal error: %v", err)) - return 1 + return state.PushError("json unmarshal error: %v", err) } return 1 } diff --git a/runner/sqlite.go b/runner/sqlite.go index d41107f..80ca6a5 100644 --- a/runner/sqlite.go +++ b/runner/sqlite.go @@ -111,20 +111,24 @@ func getPool(dbName string) (*sqlitex.Pool, error) { // sqlQuery executes a SQL query and returns results func sqlQuery(state *luajit.State) int { - // Get required parameters - if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { - state.PushString("sqlite.query: requires database name and query") - return -1 + if err := state.CheckMinArgs(2); err != nil { + return state.PushError("sqlite.query: %v", err) } - dbName := state.ToString(1) - query := state.ToString(2) + dbName, err := state.SafeToString(1) + if err != nil { + return state.PushError("sqlite.query: database name must be string") + } + + query, err := state.SafeToString(2) + if err != nil { + return state.PushError("sqlite.query: query must be string") + } // Get pool pool, err := getPool(dbName) if err != nil { - state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) - return -1 + return state.PushError("sqlite.query: %v", err) } // Get connection with timeout @@ -133,8 +137,7 @@ func sqlQuery(state *luajit.State) int { conn, err := pool.Take(ctx) if err != nil { - state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error())) - return -1 + return state.PushError("sqlite.query: connection timeout: %v", err) } defer pool.Put(conn) @@ -145,8 +148,7 @@ func sqlQuery(state *luajit.State) int { // Set up parameters if provided if state.GetTop() >= 3 && !state.IsNil(3) { if err := setupParams(state, 3, &execOpts); err != nil { - state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) - return -1 + return state.PushError("sqlite.query: %v", err) } } @@ -182,19 +184,12 @@ func sqlQuery(state *luajit.State) int { // Execute query if err := sqlitex.Execute(conn, query, &execOpts); err != nil { - state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) - return -1 + return state.PushError("sqlite.query: %v", err) } - // Create result table - state.NewTable() - for i, row := range rows { - state.PushNumber(float64(i + 1)) - if err := state.PushTable(row); err != nil { - state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) - return -1 - } - state.SetTable(-3) + // Create result using specific map type and PushValue + if err := state.PushValue(rows); err != nil { + return state.PushError("sqlite.query: %v", err) } return 1 @@ -202,20 +197,24 @@ func sqlQuery(state *luajit.State) int { // sqlExec executes a SQL statement without returning results func sqlExec(state *luajit.State) int { - // Get required parameters - if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { - state.PushString("sqlite.exec: requires database name and query") - return -1 + if err := state.CheckMinArgs(2); err != nil { + return state.PushError("sqlite.exec: %v", err) } - dbName := state.ToString(1) - query := state.ToString(2) + dbName, err := state.SafeToString(1) + if err != nil { + return state.PushError("sqlite.exec: database name must be string") + } + + query, err := state.SafeToString(2) + if err != nil { + return state.PushError("sqlite.exec: query must be string") + } // Get pool pool, err := getPool(dbName) if err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) - return -1 + return state.PushError("sqlite.exec: %v", err) } // Get connection with timeout @@ -224,8 +223,7 @@ func sqlExec(state *luajit.State) int { conn, err := pool.Take(ctx) if err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error())) - return -1 + return state.PushError("sqlite.exec: connection timeout: %v", err) } defer pool.Put(conn) @@ -235,8 +233,7 @@ func sqlExec(state *luajit.State) int { // Fast path for multi-statement scripts if strings.Contains(query, ";") && !hasParams { if err := sqlitex.ExecScript(conn, query); err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) - return -1 + return state.PushError("sqlite.exec: %v", err) } state.PushNumber(float64(conn.Changes())) return 1 @@ -245,8 +242,7 @@ func sqlExec(state *luajit.State) int { // Fast path for simple queries with no parameters if !hasParams { if err := sqlitex.Execute(conn, query, nil); err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) - return -1 + return state.PushError("sqlite.exec: %v", err) } state.PushNumber(float64(conn.Changes())) return 1 @@ -255,14 +251,12 @@ func sqlExec(state *luajit.State) int { // Create execution options for parameterized query var execOpts sqlitex.ExecOptions if err := setupParams(state, 3, &execOpts); err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) - return -1 + return state.PushError("sqlite.exec: %v", err) } // Execute with parameters if err := sqlitex.Execute(conn, query, &execOpts); err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) - return -1 + return state.PushError("sqlite.exec: %v", err) } // Return affected rows @@ -273,11 +267,17 @@ func sqlExec(state *luajit.State) int { // setupParams configures execution options with parameters from Lua func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error { if state.IsTable(paramIndex) { - params, err := state.ToTable(paramIndex) + paramsAny, err := state.SafeToTable(paramIndex) if err != nil { return fmt.Errorf("invalid parameters: %w", err) } + // Type assert to map[string]any + params, ok := paramsAny.(map[string]any) + if !ok { + return fmt.Errorf("parameters must be a table") + } + // Check for array-style params if arr, ok := params[""]; ok { if arrParams, ok := arr.([]any); ok { @@ -321,20 +321,24 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti // sqlGetOne executes a query and returns only the first row func sqlGetOne(state *luajit.State) int { - // Get required parameters - if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { - state.PushString("sqlite.get_one: requires database name and query") - return -1 + if err := state.CheckMinArgs(2); err != nil { + return state.PushError("sqlite.get_one: %v", err) } - dbName := state.ToString(1) - query := state.ToString(2) + dbName, err := state.SafeToString(1) + if err != nil { + return state.PushError("sqlite.get_one: database name must be string") + } + + query, err := state.SafeToString(2) + if err != nil { + return state.PushError("sqlite.get_one: query must be string") + } // Get pool pool, err := getPool(dbName) if err != nil { - state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) - return -1 + return state.PushError("sqlite.get_one: %v", err) } // Get connection with timeout @@ -343,8 +347,7 @@ func sqlGetOne(state *luajit.State) int { conn, err := pool.Take(ctx) if err != nil { - state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error())) - return -1 + return state.PushError("sqlite.get_one: connection timeout: %v", err) } defer pool.Put(conn) @@ -355,8 +358,7 @@ func sqlGetOne(state *luajit.State) int { // Set up parameters if provided if state.GetTop() >= 3 && !state.IsNil(3) { if err := setupParams(state, 3, &execOpts); err != nil { - state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) - return -1 + return state.PushError("sqlite.get_one: %v", err) } } @@ -395,17 +397,15 @@ func sqlGetOne(state *luajit.State) int { // Execute query if err := sqlitex.Execute(conn, query, &execOpts); err != nil { - state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) - return -1 + return state.PushError("sqlite.get_one: %v", err) } // Return result or nil if no rows if result == nil { state.PushNil() } else { - if err := state.PushTable(result); err != nil { - state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) - return -1 + if err := state.PushValue(result); err != nil { + return state.PushError("sqlite.get_one: %v", err) } } diff --git a/runner/util.go b/runner/util.go index e8c3010..7bd2e49 100644 --- a/runner/util.go +++ b/runner/util.go @@ -35,12 +35,17 @@ func RegisterUtilFunctions(state *luajit.State) error { // htmlSpecialChars converts special characters to HTML entities func htmlSpecialChars(state *luajit.State) int { - if !state.IsString(1) { + if err := state.CheckExactArgs(1); err != nil { + state.PushNil() + return 1 + } + + input, err := state.SafeToString(1) + if err != nil { state.PushNil() return 1 } - input := state.ToString(1) result := html.EscapeString(input) state.PushString(result) return 1 @@ -48,12 +53,17 @@ func htmlSpecialChars(state *luajit.State) int { // htmlEntities is a more comprehensive version of htmlSpecialChars func htmlEntities(state *luajit.State) int { - if !state.IsString(1) { + if err := state.CheckExactArgs(1); err != nil { + state.PushNil() + return 1 + } + + input, err := state.SafeToString(1) + if err != nil { state.PushNil() return 1 } - input := state.ToString(1) // First use HTML escape for standard entities result := html.EscapeString(input) @@ -86,12 +96,17 @@ func htmlEntities(state *luajit.State) int { // base64Encode encodes a string to base64 func base64Encode(state *luajit.State) int { - if !state.IsString(1) { + if err := state.CheckExactArgs(1); err != nil { + state.PushNil() + return 1 + } + + input, err := state.SafeToString(1) + if err != nil { state.PushNil() return 1 } - input := state.ToString(1) result := base64.StdEncoding.EncodeToString([]byte(input)) state.PushString(result) return 1 @@ -99,12 +114,17 @@ func base64Encode(state *luajit.State) int { // base64Decode decodes a base64 string func base64Decode(state *luajit.State) int { - if !state.IsString(1) { + if err := state.CheckExactArgs(1); err != nil { + state.PushNil() + return 1 + } + + input, err := state.SafeToString(1) + if err != nil { state.PushNil() return 1 } - input := state.ToString(1) result, err := base64.StdEncoding.DecodeString(input) if err != nil { state.PushNil()