fix lots of luajit api regressions

This commit is contained in:
Sky Johnson 2025-06-02 22:18:54 -05:00
parent 1ad3059ff0
commit 61f66d6594
5 changed files with 134 additions and 110 deletions

2
go.mod
View File

@ -5,7 +5,7 @@ go 1.24.1
require ( require (
git.sharkk.net/Go/LRU v1.0.0 git.sharkk.net/Go/LRU v1.0.0
git.sharkk.net/Sharkk/Fin v1.2.0 git.sharkk.net/Sharkk/Fin v1.2.0
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0 git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1
github.com/VictoriaMetrics/fastcache v1.12.4 github.com/VictoriaMetrics/fastcache v1.12.4
github.com/alexedwards/argon2id v1.0.0 github.com/alexedwards/argon2id v1.0.0
github.com/deneonet/benc v1.1.8 github.com/deneonet/benc v1.1.8

4
go.sum
View File

@ -2,8 +2,8 @@ git.sharkk.net/Go/LRU v1.0.0 h1:/KqdRVhHldi23aVfQZ4ss6vhCWZqA3vFiQyf1MJPpQc=
git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50= git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50=
git.sharkk.net/Sharkk/Fin v1.2.0 h1:axhme8vHRYoaB3us7PNfXzXxKOxhpS5BMuNpN8ESe6U= git.sharkk.net/Sharkk/Fin v1.2.0 h1:axhme8vHRYoaB3us7PNfXzXxKOxhpS5BMuNpN8ESe6U=
git.sharkk.net/Sharkk/Fin v1.2.0/go.mod h1:ca0Ej9yCM/vHh1o3YMvBZspme3EtbwoEL2UXN5UPXMo= git.sharkk.net/Sharkk/Fin v1.2.0/go.mod h1:ca0Ej9yCM/vHh1o3YMvBZspme3EtbwoEL2UXN5UPXMo=
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0 h1:WzIKbNIoP+P56n7EmkD9V1QZJUNMbTm3cJj2jc5qUfI= git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1 h1:e9rby0xJs8m2SAPv0di/LplDok88UyjcNjKu8S4d1BY=
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8= git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0= github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0=
github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI= github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI=
github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w= github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w=

View File

@ -1,5 +1,5 @@
--[[ --[[
sandbox.lua sandbox.lua - Rewritten with global context storage
]]-- ]]--
__http_response = {} __http_response = {}
@ -8,6 +8,9 @@ __module_bytecode = {}
__ready_modules = {} __ready_modules = {}
__EXIT_SENTINEL = {} -- Unique object for exit identification __EXIT_SENTINEL = {} -- Unique object for exit identification
-- Global context storage for reliable access
local _current_ctx = nil
-- ====================================================================== -- ======================================================================
-- CORE SANDBOX FUNCTIONALITY -- CORE SANDBOX FUNCTIONALITY
-- ====================================================================== -- ======================================================================
@ -38,17 +41,20 @@ end
-- Execute script with clean environment -- Execute script with clean environment
function __execute_script(fn, ctx) function __execute_script(fn, ctx)
__http_response = nil __http_response = nil
_current_ctx = ctx -- Store globally for function access
local env = __create_env(ctx) local env = __create_env(ctx)
env.exit = exit env.exit = exit
setfenv(fn, env) setfenv(fn, env)
local ok, result = pcall(fn) local ok, result = pcall(fn)
_current_ctx = nil -- Clean up after execution
if not ok then if not ok then
if result == __EXIT_SENTINEL then if result == __EXIT_SENTINEL then
return return
end end
error(result, 0) error(result, 0)
end end
@ -258,14 +264,13 @@ function cookie_get(name)
error("cookie_get: name must be a string", 2) error("cookie_get: name must be a string", 2)
end end
local env = getfenv(2) if _current_ctx then
if _current_ctx.cookies then
if env.ctx and env.ctx.cookies then return _current_ctx.cookies[name]
return env.ctx.cookies[name] end
end if _current_ctx._request_cookies then
return _current_ctx._request_cookies[name]
if env.ctx and env.ctx._request_cookies then end
return env.ctx._request_cookies[name]
end end
return nil return nil
@ -289,10 +294,8 @@ function session_get(key)
error("session_get: key must be a string", 2) error("session_get: key must be a string", 2)
end end
local env = getfenv(2) if _current_ctx and _current_ctx.session and _current_ctx.session.data then
return _current_ctx.session.data[key]
if env.ctx and env.ctx.session and env.ctx.session.data then
return env.ctx.session.data[key]
end end
return nil return nil
@ -302,7 +305,7 @@ function session_set(key, value)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session_set: key must be a string", 2) error("session_set: key must be a string", 2)
end end
if type(value) == nil then if value == nil then
error("session_set: value cannot be nil", 2) error("session_set: value cannot be nil", 2)
end end
@ -310,30 +313,29 @@ function session_set(key, value)
resp.session = resp.session or {} resp.session = resp.session or {}
resp.session[key] = value resp.session[key] = value
local env = getfenv(2) -- Update current context session data
if env.ctx and env.ctx.session and env.ctx.session.data then if _current_ctx and _current_ctx.session and _current_ctx.session.data then
env.ctx.session.data[key] = value _current_ctx.session.data[key] = value
end end
end end
function session_id() function session_id()
local env = getfenv(2) if _current_ctx and _current_ctx.session then
return _current_ctx.session.id
if env.ctx and env.ctx.session then
return env.ctx.session.id
end end
return nil return nil
end end
function session_get_all() function session_get_all()
local env = getfenv(2) if _current_ctx and _current_ctx.session and _current_ctx.session.data then
-- Return a copy to prevent modification
if env.ctx and env.ctx.session then local copy = {}
return env.ctx.session.data for k, v in pairs(_current_ctx.session.data) do
copy[k] = v
end
return copy
end end
return {}
return nil
end end
function session_delete(key) function session_delete(key)
@ -345,17 +347,16 @@ function session_delete(key)
resp.session = resp.session or {} resp.session = resp.session or {}
resp.session[key] = "__SESSION_DELETE_MARKER__" resp.session[key] = "__SESSION_DELETE_MARKER__"
local env = getfenv(2) -- Update current context
if env.ctx and env.ctx.session and env.ctx.session.data then if _current_ctx and _current_ctx.session and _current_ctx.session.data then
env.ctx.session.data[key] = nil _current_ctx.session.data[key] = nil
end end
end end
function session_clear() function session_clear()
local env = getfenv(2) if _current_ctx and _current_ctx.session and _current_ctx.session.data then
if env.ctx and env.ctx.session and env.ctx.session.data then for k, _ in pairs(_current_ctx.session.data) do
for k, _ in pairs(env.ctx.session.data) do _current_ctx.session.data[k] = nil
env.ctx.session.data[k] = nil
end end
end end
@ -384,11 +385,7 @@ function csrf_field()
end end
function csrf_validate() function csrf_validate()
local env = getfenv(2) local token = session_get("_csrf_token")
local token = false
if env.ctx and env.ctx.session and env.ctx.session.data then
token = env.ctx.session.data["_csrf_token"]
end
if not token then if not token then
http_set_status(403) http_set_status(403)
@ -397,13 +394,13 @@ function csrf_validate()
end end
local request_token = nil local request_token = nil
if env.ctx and env.ctx.form then if _current_ctx and _current_ctx.form then
request_token = env.ctx.form._csrf_token request_token = _current_ctx.form._csrf_token
end end
if not request_token and env.ctx and env.ctx._request_headers then if not request_token and _current_ctx and _current_ctx._request_headers then
request_token = env.ctx._request_headers["x-csrf-token"] or request_token = _current_ctx._request_headers["x-csrf-token"] or
env.ctx._request_headers["csrf-token"] _current_ctx._request_headers["csrf-token"]
end end
if not request_token or request_token ~= token then if not request_token or request_token ~= token then

View File

@ -177,12 +177,15 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
// Extract headers using ForEachTableKV // Extract headers using ForEachTableKV
if headerTable, ok := state.GetFieldTable(-1, "headers"); ok { if headerTable, ok := state.GetFieldTable(-1, "headers"); ok {
if headers, ok := headerTable.(map[string]any); ok { switch headers := headerTable.(type) {
case map[string]any:
for k, v := range headers { for k, v := range headers {
if str, ok := v.(string); ok { if str, ok := v.(string); ok {
response.Headers[k] = str response.Headers[k] = str
} }
} }
case map[string]string:
maps.Copy(response.Headers, headers)
} }
} }
@ -207,8 +210,19 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
// Extract session data // Extract session data
if session, ok := state.GetFieldTable(-1, "session"); ok { if session, ok := state.GetFieldTable(-1, "session"); ok {
if sessMap, ok := session.(map[string]any); ok { switch sessMap := session.(type) {
case map[string]any:
maps.Copy(response.SessionData, sessMap) maps.Copy(response.SessionData, sessMap)
case map[string]string:
for k, v := range sessMap {
response.SessionData[k] = v
}
case map[string]int:
for k, v := range sessMap {
response.SessionData[k] = v
}
default:
logger.Debugf("Unexpected session type: %T", session)
} }
} }

View File

@ -21,24 +21,21 @@ var (
dbPools = make(map[string]*sqlitex.Pool) dbPools = make(map[string]*sqlitex.Pool)
poolsMu sync.RWMutex poolsMu sync.RWMutex
dataDir string dataDir string
poolSize = 8 // Default, will be set to match runner pool size poolSize = 8
connTimeout = 5 * time.Second connTimeout = 5 * time.Second
) )
// InitSQLite initializes the SQLite subsystem
func InitSQLite(dir string) { func InitSQLite(dir string) {
dataDir = dir dataDir = dir
logger.Infof("SQLite is g2g! %s", color.Yellow(dir)) logger.Infof("SQLite is g2g! %s", color.Yellow(dir))
} }
// SetSQLitePoolSize sets the pool size to match the runner pool size
func SetSQLitePoolSize(size int) { func SetSQLitePoolSize(size int) {
if size > 0 { if size > 0 {
poolSize = size poolSize = size
} }
} }
// CleanupSQLite closes all database connections
func CleanupSQLite() { func CleanupSQLite() {
poolsMu.Lock() poolsMu.Lock()
defer poolsMu.Unlock() defer poolsMu.Unlock()
@ -53,15 +50,12 @@ func CleanupSQLite() {
logger.Debugf("SQLite connections closed") logger.Debugf("SQLite connections closed")
} }
// getPool returns a connection pool for the database
func getPool(dbName string) (*sqlitex.Pool, error) { func getPool(dbName string) (*sqlitex.Pool, error) {
// Validate database name
dbName = filepath.Base(dbName) dbName = filepath.Base(dbName)
if dbName == "" || dbName[0] == '.' { if dbName == "" || dbName[0] == '.' {
return nil, fmt.Errorf("invalid database name") return nil, fmt.Errorf("invalid database name")
} }
// Check for existing pool
poolsMu.RLock() poolsMu.RLock()
pool, exists := dbPools[dbName] pool, exists := dbPools[dbName]
if exists { if exists {
@ -70,21 +64,17 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
} }
poolsMu.RUnlock() poolsMu.RUnlock()
// Create new pool under write lock
poolsMu.Lock() poolsMu.Lock()
defer poolsMu.Unlock() defer poolsMu.Unlock()
// Double-check if a pool was created while waiting for lock
if pool, exists = dbPools[dbName]; exists { if pool, exists = dbPools[dbName]; exists {
return pool, nil return pool, nil
} }
// Create new pool with proper size
dbPath := filepath.Join(dataDir, dbName+".db") dbPath := filepath.Join(dataDir, dbName+".db")
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{ pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
PoolSize: poolSize, PoolSize: poolSize,
PrepareConn: func(conn *sqlite.Conn) error { PrepareConn: func(conn *sqlite.Conn) error {
// Execute PRAGMA statements individually
pragmas := []string{ pragmas := []string{
"PRAGMA journal_mode = WAL", "PRAGMA journal_mode = WAL",
"PRAGMA synchronous = NORMAL", "PRAGMA synchronous = NORMAL",
@ -109,7 +99,6 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return pool, nil return pool, nil
} }
// sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int { func sqlQuery(state *luajit.State) int {
if err := state.CheckMinArgs(2); err != nil { if err := state.CheckMinArgs(2); err != nil {
return state.PushError("sqlite.query: %v", err) return state.PushError("sqlite.query: %v", err)
@ -125,13 +114,11 @@ func sqlQuery(state *luajit.State) int {
return state.PushError("sqlite.query: query must be string") return state.PushError("sqlite.query: query must be string")
} }
// Get pool
pool, err := getPool(dbName) pool, err := getPool(dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.query: %v", err) return state.PushError("sqlite.query: %v", err)
} }
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout) ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel() defer cancel()
@ -141,18 +128,15 @@ func sqlQuery(state *luajit.State) int {
} }
defer pool.Put(conn) defer pool.Put(conn)
// Create execution options
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
rows := make([]map[string]any, 0, 16) rows := make([]any, 0, 16)
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) { if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil { if err := setupParams(state, 3, &execOpts); err != nil {
return state.PushError("sqlite.query: %v", err) return state.PushError("sqlite.query: %v", err)
} }
} }
// Set up result function
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
row := make(map[string]any) row := make(map[string]any)
colCount := stmt.ColumnCount() colCount := stmt.ColumnCount()
@ -182,12 +166,10 @@ func sqlQuery(state *luajit.State) int {
return nil return nil
} }
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil { if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
return state.PushError("sqlite.query: %v", err) return state.PushError("sqlite.query: %v", err)
} }
// Create result using specific map type and PushValue
if err := state.PushValue(rows); err != nil { if err := state.PushValue(rows); err != nil {
return state.PushError("sqlite.query: %v", err) return state.PushError("sqlite.query: %v", err)
} }
@ -195,7 +177,6 @@ func sqlQuery(state *luajit.State) int {
return 1 return 1
} }
// sqlExec executes a SQL statement without returning results
func sqlExec(state *luajit.State) int { func sqlExec(state *luajit.State) int {
if err := state.CheckMinArgs(2); err != nil { if err := state.CheckMinArgs(2); err != nil {
return state.PushError("sqlite.exec: %v", err) return state.PushError("sqlite.exec: %v", err)
@ -211,13 +192,11 @@ func sqlExec(state *luajit.State) int {
return state.PushError("sqlite.exec: query must be string") return state.PushError("sqlite.exec: query must be string")
} }
// Get pool
pool, err := getPool(dbName) pool, err := getPool(dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.exec: %v", err) return state.PushError("sqlite.exec: %v", err)
} }
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout) ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel() defer cancel()
@ -227,10 +206,8 @@ func sqlExec(state *luajit.State) int {
} }
defer pool.Put(conn) defer pool.Put(conn)
// Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3) hasParams := state.GetTop() >= 3 && !state.IsNil(3)
// Fast path for multi-statement scripts
if strings.Contains(query, ";") && !hasParams { if strings.Contains(query, ";") && !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil { if err := sqlitex.ExecScript(conn, query); err != nil {
return state.PushError("sqlite.exec: %v", err) return state.PushError("sqlite.exec: %v", err)
@ -239,7 +216,6 @@ func sqlExec(state *luajit.State) int {
return 1 return 1
} }
// Fast path for simple queries with no parameters
if !hasParams { if !hasParams {
if err := sqlitex.Execute(conn, query, nil); err != nil { if err := sqlitex.Execute(conn, query, nil); err != nil {
return state.PushError("sqlite.exec: %v", err) return state.PushError("sqlite.exec: %v", err)
@ -248,23 +224,19 @@ func sqlExec(state *luajit.State) int {
return 1 return 1
} }
// Create execution options for parameterized query
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
if err := setupParams(state, 3, &execOpts); err != nil { if err := setupParams(state, 3, &execOpts); err != nil {
return state.PushError("sqlite.exec: %v", err) return state.PushError("sqlite.exec: %v", err)
} }
// Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil { if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
return state.PushError("sqlite.exec: %v", err) return state.PushError("sqlite.exec: %v", err)
} }
// Return affected rows
state.PushNumber(float64(conn.Changes())) state.PushNumber(float64(conn.Changes()))
return 1 return 1
} }
// setupParams configures execution options with parameters from Lua
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error { func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
if state.IsTable(paramIndex) { if state.IsTable(paramIndex) {
paramsAny, err := state.SafeToTable(paramIndex) paramsAny, err := state.SafeToTable(paramIndex)
@ -272,25 +244,31 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
return fmt.Errorf("invalid parameters: %w", err) return fmt.Errorf("invalid parameters: %w", err)
} }
// Type assert to map[string]any switch params := paramsAny.(type) {
params, ok := paramsAny.(map[string]any) case map[string]any:
if !ok { if arr, ok := params[""]; ok {
return fmt.Errorf("parameters must be a table") if arrParams, ok := arr.([]any); ok {
} execOpts.Args = arrParams
} else if floatArr, ok := arr.([]float64); ok {
// Check for array-style params args := make([]any, len(floatArr))
if arr, ok := params[""]; ok { for i, v := range floatArr {
if arrParams, ok := arr.([]any); ok { args[i] = v
execOpts.Args = arrParams }
} else if floatArr, ok := arr.([]float64); ok { execOpts.Args = args
args := make([]any, len(floatArr))
for i, v := range floatArr {
args[i] = v
} }
execOpts.Args = args } else {
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
} }
} else {
// Named parameters case map[string]string:
named := make(map[string]any, len(params)) named := make(map[string]any, len(params))
for k, v := range params { for k, v := range params {
if len(k) > 0 && k[0] != ':' { if len(k) > 0 && k[0] != ':' {
@ -300,9 +278,53 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
} }
} }
execOpts.Named = named execOpts.Named = named
case map[string]int:
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
case map[int]any:
named := make(map[string]any, len(params))
for k, v := range params {
named[fmt.Sprintf(":%d", k)] = v
}
execOpts.Named = named
case []any:
execOpts.Args = params
case []string:
args := make([]any, len(params))
for i, v := range params {
args[i] = v
}
execOpts.Args = args
case []int:
args := make([]any, len(params))
for i, v := range params {
args[i] = v
}
execOpts.Args = args
case []float64:
args := make([]any, len(params))
for i, v := range params {
args[i] = v
}
execOpts.Args = args
default:
return fmt.Errorf("unsupported parameter type: %T", params)
} }
} else { } else {
// Positional parameters from stack
count := state.GetTop() - 2 count := state.GetTop() - 2
args := make([]any, count) args := make([]any, count)
for i := range count { for i := range count {
@ -319,7 +341,6 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
return nil return nil
} }
// sqlGetOne executes a query and returns only the first row
func sqlGetOne(state *luajit.State) int { func sqlGetOne(state *luajit.State) int {
if err := state.CheckMinArgs(2); err != nil { if err := state.CheckMinArgs(2); err != nil {
return state.PushError("sqlite.get_one: %v", err) return state.PushError("sqlite.get_one: %v", err)
@ -335,13 +356,11 @@ func sqlGetOne(state *luajit.State) int {
return state.PushError("sqlite.get_one: query must be string") return state.PushError("sqlite.get_one: query must be string")
} }
// Get pool
pool, err := getPool(dbName) pool, err := getPool(dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.get_one: %v", err) return state.PushError("sqlite.get_one: %v", err)
} }
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout) ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel() defer cancel()
@ -351,21 +370,18 @@ func sqlGetOne(state *luajit.State) int {
} }
defer pool.Put(conn) defer pool.Put(conn)
// Create execution options
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
var result map[string]any var result map[string]any
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) { if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil { if err := setupParams(state, 3, &execOpts); err != nil {
return state.PushError("sqlite.get_one: %v", err) return state.PushError("sqlite.get_one: %v", err)
} }
} }
// Set up result function to get only first row
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
if result != nil { if result != nil {
return nil // Already got first row return nil
} }
result = make(map[string]any) result = make(map[string]any)
@ -395,12 +411,10 @@ func sqlGetOne(state *luajit.State) int {
return nil return nil
} }
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil { if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
return state.PushError("sqlite.get_one: %v", err) return state.PushError("sqlite.get_one: %v", err)
} }
// Return result or nil if no rows
if result == nil { if result == nil {
state.PushNil() state.PushNil()
} else { } else {
@ -412,7 +426,6 @@ func sqlGetOne(state *luajit.State) int {
return 1 return 1
} }
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
func RegisterSQLiteFunctions(state *luajit.State) error { func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
return err return err