major sqlite bug fix, minor lua state closing fix, add headers to lua ctx

This commit is contained in:
Sky Johnson 2025-05-30 13:24:58 -05:00
parent 5b698f31e4
commit 266da9fd23
4 changed files with 268 additions and 303 deletions

View File

@ -210,10 +210,17 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
luaCtx.Set("host", string(ctx.Host())) luaCtx.Set("host", string(ctx.Host()))
luaCtx.Set("session", sessionMap) luaCtx.Set("session", sessionMap)
// Add headers to context
headers := make(map[string]any)
ctx.Request.Header.VisitAll(func(key, value []byte) {
headers[string(key)] = string(value)
})
luaCtx.Set("headers", headers)
// Handle params // Handle params
if params != nil && params.Count > 0 { if params != nil && params.Count > 0 {
paramMap := s.paramsPool.Get().(map[string]any) paramMap := s.paramsPool.Get().(map[string]any)
for i := 0; i < params.Count; i++ { for i := range params.Count {
paramMap[params.Keys[i]] = params.Values[i] paramMap[params.Keys[i]] = params.Values[i]
} }
luaCtx.Set("params", paramMap) luaCtx.Set("params", paramMap)

View File

@ -16,9 +16,7 @@ local connection_mt = {
end end
local normalized_params = normalize_params(params, ...) local normalized_params = normalize_params(params, ...)
local results, token = __sqlite_query(self.db_name, query, normalized_params, self.conn_token) return __sqlite_query(self.db_name, query, normalized_params)
self.conn_token = token
return results
end, end,
exec = function(self, query, params, ...) exec = function(self, query, params, ...)
@ -27,18 +25,16 @@ local connection_mt = {
end end
local normalized_params = normalize_params(params, ...) local normalized_params = normalize_params(params, ...)
local affected, token = __sqlite_exec(self.db_name, query, normalized_params, self.conn_token) return __sqlite_exec(self.db_name, query, normalized_params)
self.conn_token = token
return affected
end, end,
close = function(self) get_one = function(self, query, params, ...)
if self.conn_token then if type(query) ~= "string" then
local success = __sqlite_close(self.conn_token) error("connection:get_one: query must be a string", 2)
self.conn_token = nil
return success
end end
return false
local normalized_params = normalize_params(params, ...)
return __sqlite_get_one(self.db_name, query, normalized_params)
end, end,
insert = function(self, table_name, data, columns) insert = function(self, table_name, data, columns)
@ -249,20 +245,6 @@ local connection_mt = {
return self:exec(query, normalize_params(params, ...)) return self:exec(query, normalize_params(params, ...))
end, end,
get_one = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:get_one: query must be a string", 2)
end
local limited_query = query
if not string.contains(query:lower(), "limit") then
limited_query = query .. " LIMIT 1"
end
local results = self:query(limited_query, normalize_params(params, ...))
return results[1]
end,
exists = function(self, table_name, where, params, ...) exists = function(self, table_name, where, params, ...)
if type(table_name) ~= "string" then if type(table_name) ~= "string" then
error("connection:exists: table_name must be a string", 2) error("connection:exists: table_name must be a string", 2)
@ -310,7 +292,6 @@ return function(db_name)
end end
return setmetatable({ return setmetatable({
db_name = db_name, db_name = db_name
conn_token = nil
}, connection_mt) }, connection_mt)
end end

View File

@ -33,7 +33,7 @@ type State struct {
L *luajit.State // The Lua state L *luajit.State // The Lua state
sandbox *Sandbox // Associated sandbox sandbox *Sandbox // Associated sandbox
index int // Index for debugging index int // Index for debugging
inUse bool // Whether the state is currently in use inUse atomic.Bool // Whether the state is currently in use
} }
// Runner runs Lua scripts using a pool of Lua states // Runner runs Lua scripts using a pool of Lua states
@ -115,14 +115,16 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
InitSQLite(runner.dataDir) InitSQLite(runner.dataDir)
InitFS(runner.fsDir) InitFS(runner.fsDir)
SetSQLitePoolSize(runner.poolSize)
// Initialize states and pool // Initialize states and pool
runner.states = make([]*State, runner.poolSize) runner.states = make([]*State, runner.poolSize)
runner.statePool = make(chan int, runner.poolSize) runner.statePool = make(chan int, runner.poolSize)
// Create and initialize all states // Create and initialize all states
if err := runner.initializeStates(); err != nil { if err := runner.initializeStates(); err != nil {
CleanupSQLite() // Clean up SQLite connections CleanupSQLite()
runner.Close() // Clean up already created states runner.Close()
return nil, err return nil, err
} }
@ -190,7 +192,6 @@ func (r *Runner) createState(index int) (*State, error) {
L: L, L: L,
sandbox: sb, sandbox: sb,
index: index, index: index,
inUse: false,
}, nil }, nil
} }
@ -215,29 +216,26 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context,
// Got a state // Got a state
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
case <-time.After(5 * time.Second): case <-time.After(1 * time.Second):
return nil, ErrTimeout return nil, ErrTimeout
} }
// Get the actual state
state := r.states[stateIndex] state := r.states[stateIndex]
if state == nil { if state == nil {
r.statePool <- stateIndex r.statePool <- stateIndex
return nil, ErrStateNotReady return nil, ErrStateNotReady
} }
// Mark state as in use // Use atomic operations
state.inUse = true state.inUse.Store(true)
// Ensure state is returned to pool when done
defer func() { defer func() {
state.inUse = false state.inUse.Store(false)
if r.isRunning.Load() { if r.isRunning.Load() {
select { select {
case r.statePool <- stateIndex: case r.statePool <- stateIndex:
// State returned to pool
default: default:
// Pool is full or closed // Pool is full or closed, state will be cleaned up by Close()
} }
} }
}() }()
@ -267,21 +265,45 @@ func (r *Runner) Close() error {
r.isRunning.Store(false) r.isRunning.Store(false)
// Drain the state pool // Drain all states from the pool
for { for {
select { select {
case <-r.statePool: case <-r.statePool:
// Drain one state
default: default:
// Pool is empty goto waitForInUse
goto cleanup
} }
} }
cleanup: waitForInUse:
// Clean up all states // Wait for in-use states to finish (with timeout)
timeout := time.Now().Add(10 * time.Second)
for {
allIdle := true
for _, state := range r.states {
if state != nil && state.inUse.Load() {
allIdle = false
break
}
}
if allIdle {
break
}
if time.Now().After(timeout) {
logger.Warning("Timeout waiting for states to finish during shutdown, forcing close")
break
}
time.Sleep(10 * time.Millisecond)
}
// Now safely close all states
for i, state := range r.states { for i, state := range r.states {
if state != nil { if state != nil {
if state.inUse.Load() {
logger.Warning("Force closing state %d that is still in use", i)
}
state.L.Cleanup() state.L.Cleanup()
state.L.Close() state.L.Close()
r.states[i] = nil r.states[i] = nil
@ -310,19 +332,40 @@ func (r *Runner) RefreshStates() error {
for { for {
select { select {
case <-r.statePool: case <-r.statePool:
// Drain one state
default: default:
// Pool is empty goto waitForInUse
goto cleanup
} }
} }
cleanup: waitForInUse:
// Destroy all existing states // Wait for in-use states to finish (with timeout)
timeout := time.Now().Add(10 * time.Second)
for {
allIdle := true
for _, state := range r.states {
if state != nil && state.inUse.Load() {
allIdle = false
break
}
}
if allIdle {
break
}
if time.Now().After(timeout) {
logger.Warning("Timeout waiting for states to finish, forcing refresh")
break
}
time.Sleep(10 * time.Millisecond)
}
// Now safely destroy all states
for i, state := range r.states { for i, state := range r.states {
if state != nil { if state != nil {
if state.inUse { if state.inUse.Load() {
logger.Warning("Attempting to refresh state %d that is in use", i) logger.Warning("Force closing state %d that is still in use", i)
} }
state.L.Cleanup() state.L.Cleanup()
state.L.Close() state.L.Close()
@ -367,7 +410,7 @@ func (r *Runner) RefreshModule(moduleName string) bool {
success := true success := true
for _, state := range r.states { for _, state := range r.states {
if state == nil || state.inUse { if state == nil || state.inUse.Load() {
continue continue
} }
@ -403,7 +446,7 @@ func (r *Runner) GetActiveStateCount() int {
count := 0 count := 0
for _, state := range r.states { for _, state := range r.states {
if state != nil && state.inUse { if state != nil && state.inUse.Load() {
count++ count++
} }
} }
@ -459,10 +502,10 @@ func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
return nil, ErrStateNotReady return nil, ErrStateNotReady
} }
state.inUse = true state.inUse.Store(true)
defer func() { defer func() {
state.inUse = false state.inUse.Store(false)
if r.isRunning.Load() { if r.isRunning.Load() {
select { select {
case r.statePool <- stateIndex: case r.statePool <- stateIndex:

View File

@ -2,8 +2,6 @@ package runner
import ( import (
"context" "context"
"crypto/rand"
"encoding/base64"
"fmt" "fmt"
"path/filepath" "path/filepath"
"strings" "strings"
@ -19,71 +17,29 @@ import (
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
) )
// DbPools maintains database connection pools
var ( var (
dbPools = make(map[string]*sqlitex.Pool) dbPools = make(map[string]*sqlitex.Pool)
poolsMu sync.RWMutex poolsMu sync.RWMutex
dataDir string dataDir string
poolSize = 8 // Default, will be set to match runner pool size
// Connection tracking connTimeout = 5 * time.Second
activeConns = make(map[string]*TrackedConn)
activeConnMu sync.RWMutex
connTimeout = 5 * time.Minute
) )
// TrackedConn holds a connection with usage tracking
type TrackedConn struct {
Conn *sqlite.Conn
Pool *sqlitex.Pool
DBName string
LastUsed time.Time
}
// generateConnToken creates a unique token for connection tracking
func generateConnToken() string {
b := make([]byte, 8)
rand.Read(b)
return base64.URLEncoding.EncodeToString(b)
}
// InitSQLite initializes the SQLite subsystem // InitSQLite initializes the SQLite subsystem
func InitSQLite(dir string) { func InitSQLite(dir string) {
dataDir = dir dataDir = dir
logger.Info("SQLite is g2g! %s", color.Apply(dir, color.Yellow)) logger.Info("SQLite is g2g! %s", color.Apply(dir, color.Yellow))
// Start connection cleanup goroutine
go cleanupIdleConnections()
} }
// cleanupIdleConnections periodically checks for and removes idle connections // SetSQLitePoolSize sets the pool size to match the runner pool size
func cleanupIdleConnections() { func SetSQLitePoolSize(size int) {
ticker := time.NewTicker(30 * time.Second) if size > 0 {
defer ticker.Stop() poolSize = size
for range ticker.C {
now := time.Now()
activeConnMu.Lock()
for token, conn := range activeConns {
if conn.LastUsed.Add(connTimeout).Before(now) {
logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName)
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
}
}
activeConnMu.Unlock()
} }
} }
// CleanupSQLite closes all database connections // CleanupSQLite closes all database connections
func CleanupSQLite() { func CleanupSQLite() {
activeConnMu.Lock()
for token, conn := range activeConns {
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
}
activeConnMu.Unlock()
poolsMu.Lock() poolsMu.Lock()
defer poolsMu.Unlock() defer poolsMu.Unlock()
@ -123,74 +79,36 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return pool, nil return pool, nil
} }
// Create new pool // Create new pool with proper size
dbPath := filepath.Join(dataDir, dbName+".db") dbPath := filepath.Join(dataDir, dbName+".db")
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
PoolSize: poolSize,
PrepareConn: func(conn *sqlite.Conn) error {
// Execute PRAGMA statements individually
pragmas := []string{
"PRAGMA journal_mode = WAL",
"PRAGMA synchronous = NORMAL",
"PRAGMA cache_size = 1000",
"PRAGMA foreign_keys = ON",
"PRAGMA temp_store = MEMORY",
}
for _, pragma := range pragmas {
if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
return err
}
}
return nil
},
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
dbPools[dbName] = pool dbPools[dbName] = pool
logger.Debug("Created SQLite pool for %s (size: %d)", dbName, poolSize)
return pool, nil return pool, nil
} }
// getConnection retrieves or creates a tracked connection
func getConnection(token, dbName string) (*TrackedConn, string, error) {
// If token is provided, try to get existing connection
if token != "" {
activeConnMu.RLock()
conn, exists := activeConns[token]
activeConnMu.RUnlock()
if exists {
conn.LastUsed = time.Now()
return conn, token, nil
}
}
// Token not provided or connection not found, create new
pool, err := getPool(dbName)
if err != nil {
return nil, "", err
}
conn, err := pool.Take(context.Background())
if err != nil {
return nil, "", err
}
// Generate new token
newToken := generateConnToken()
trackedConn := &TrackedConn{
Conn: conn,
Pool: pool,
DBName: dbName,
LastUsed: time.Now(),
}
activeConnMu.Lock()
activeConns[newToken] = trackedConn
activeConnMu.Unlock()
return trackedConn, newToken, nil
}
// releaseConnection releases a connection back to the pool
func releaseConnection(token string) bool {
activeConnMu.Lock()
defer activeConnMu.Unlock()
conn, exists := activeConns[token]
if !exists {
return false
}
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
return true
}
// sqlQuery executes a SQL query and returns results // sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int { func sqlQuery(state *luajit.State) int {
// Get required parameters // Get required parameters
@ -202,20 +120,23 @@ func sqlQuery(state *luajit.State) int {
dbName := state.ToString(1) dbName := state.ToString(1)
query := state.ToString(2) query := state.ToString(2)
// Get connection token (optional) // Get pool
var connToken string pool, err := getPool(dbName)
if state.GetTop() >= 4 && state.IsString(4) {
connToken = state.ToString(4)
}
// Get connection
trackedConn, newToken, err := getConnection(connToken, dbName)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1 return -1
} }
conn := trackedConn.Conn // Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Create execution options // Create execution options
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
@ -223,65 +144,10 @@ func sqlQuery(state *luajit.State) int {
// Set up parameters if provided // Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) { if state.GetTop() >= 3 && !state.IsNil(3) {
if state.IsTable(3) { if err := setupParams(state, 3, &execOpts); err != nil {
params, err := state.ToTable(3) state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error()))
return -1 return -1
} }
// Check for array-style params
if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
execOpts.Args = arrParams
} else if floatArr, ok := arr.([]float64); ok {
args := make([]any, len(floatArr))
for i, v := range floatArr {
args[i] = v
}
execOpts.Args = args
}
} else {
// Named parameters
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
}
} else {
// Positional parameters
count := state.GetTop() - 2
if state.IsString(4) {
count-- // Don't include connection token
}
args := make([]any, count)
for i := range count {
idx := i + 3
switch state.GetType(idx) {
case luajit.TypeNumber:
args[i] = state.ToNumber(idx)
case luajit.TypeString:
args[i] = state.ToString(idx)
case luajit.TypeBoolean:
args[i] = state.ToBoolean(idx)
case luajit.TypeNil:
args[i] = nil
default:
val, err := state.ToValue(idx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error()))
return -1
}
args[i] = val
}
}
execOpts.Args = args
}
} }
// Set up result function // Set up result function
@ -300,8 +166,12 @@ func sqlQuery(state *luajit.State) int {
row[colName] = stmt.ColumnText(i) row[colName] = stmt.ColumnText(i)
case sqlite.TypeBlob: case sqlite.TypeBlob:
blobSize := stmt.ColumnLen(i) blobSize := stmt.ColumnLen(i)
if blobSize > 0 {
buf := make([]byte, blobSize) buf := make([]byte, blobSize)
row[colName] = stmt.ColumnBytes(i, buf) row[colName] = stmt.ColumnBytes(i, buf)
} else {
row[colName] = []byte{}
}
case sqlite.TypeNull: case sqlite.TypeNull:
row[colName] = nil row[colName] = nil
} }
@ -327,10 +197,7 @@ func sqlQuery(state *luajit.State) int {
state.SetTable(-3) state.SetTable(-3)
} }
// Return connection token return 1
state.PushString(newToken)
return 2
} }
// sqlExec executes a SQL statement without returning results // sqlExec executes a SQL statement without returning results
@ -344,58 +211,73 @@ func sqlExec(state *luajit.State) int {
dbName := state.ToString(1) dbName := state.ToString(1)
query := state.ToString(2) query := state.ToString(2)
// Get connection token (optional) // Get pool
var connToken string pool, err := getPool(dbName)
if state.GetTop() >= 4 && state.IsString(4) {
connToken = state.ToString(4)
}
// Get connection
trackedConn, newToken, err := getConnection(connToken, dbName)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1 return -1
} }
conn := trackedConn.Conn // Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Check if parameters are provided // Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3) hasParams := state.GetTop() >= 3 && !state.IsNil(3)
// Fast path for multi-statement scripts - use ExecScript // Fast path for multi-statement scripts
if strings.Contains(query, ";") && !hasParams { if strings.Contains(query, ";") && !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil { if err := sqlitex.ExecScript(conn, query); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1 return -1
} }
state.PushNumber(float64(conn.Changes())) state.PushNumber(float64(conn.Changes()))
state.PushString(newToken) return 1
return 2
} }
// Fast path for simple queries with no parameters // Fast path for simple queries with no parameters
if !hasParams { if !hasParams {
// Use Execute for simple statements without parameters
if err := sqlitex.Execute(conn, query, nil); err != nil { if err := sqlitex.Execute(conn, query, nil); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1 return -1
} }
state.PushNumber(float64(conn.Changes())) state.PushNumber(float64(conn.Changes()))
state.PushString(newToken) return 1
return 2
} }
// Create execution options for parameterized query // Create execution options for parameterized query
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
if err := setupParams(state, 3, &execOpts); err != nil {
// Set up parameters state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
if state.IsTable(3) {
params, err := state.ToTable(3)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error()))
return -1 return -1
} }
// Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// Return affected rows
state.PushNumber(float64(conn.Changes()))
return 1
}
// setupParams configures execution options with parameters from Lua
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
if state.IsTable(paramIndex) {
params, err := state.ToTable(paramIndex)
if err != nil {
return fmt.Errorf("invalid parameters: %w", err)
}
// Check for array-style params // Check for array-style params
if arr, ok := params[""]; ok { if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok { if arrParams, ok := arr.([]any); ok {
@ -420,59 +302,111 @@ func sqlExec(state *luajit.State) int {
execOpts.Named = named execOpts.Named = named
} }
} else { } else {
// Positional parameters // Positional parameters from stack
count := state.GetTop() - 2 count := state.GetTop() - 2
if state.IsString(4) {
count-- // Don't include connection token
}
args := make([]any, count) args := make([]any, count)
for i := range count { for i := range count {
idx := i + 3 idx := i + 3
switch state.GetType(idx) {
case luajit.TypeNumber:
args[i] = state.ToNumber(idx)
case luajit.TypeString:
args[i] = state.ToString(idx)
case luajit.TypeBoolean:
args[i] = state.ToBoolean(idx)
case luajit.TypeNil:
args[i] = nil
default:
val, err := state.ToValue(idx) val, err := state.ToValue(idx)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error())) return fmt.Errorf("invalid parameter %d: %w", i+1, err)
return -1
} }
args[i] = val args[i] = val
} }
}
execOpts.Args = args execOpts.Args = args
} }
// Execute with parameters return nil
if err := sqlitex.Execute(conn, query, &execOpts); err != nil { }
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
// sqlGetOne executes a query and returns only the first row
func sqlGetOne(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.get_one: requires database name and query")
return -1 return -1
} }
// Return affected rows and connection token dbName := state.ToString(1)
state.PushNumber(float64(conn.Changes())) query := state.ToString(2)
state.PushString(newToken)
return 2
}
// sqlClose releases a connection back to the pool // Get pool
func sqlClose(state *luajit.State) int { pool, err := getPool(dbName)
if state.GetTop() < 1 || !state.IsString(1) { if err != nil {
state.PushString("sqlite.close: requires connection token") state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1 return -1
} }
token := state.ToString(1) // Get connection with timeout
if releaseConnection(token) { ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
state.PushBoolean(true) defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Create execution options
var execOpts sqlitex.ExecOptions
var result map[string]any
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
}
// Set up result function to get only first row
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
if result != nil {
return nil // Already got first row
}
result = make(map[string]any)
colCount := stmt.ColumnCount()
for i := range colCount {
colName := stmt.ColumnName(i)
switch stmt.ColumnType(i) {
case sqlite.TypeInteger:
result[colName] = stmt.ColumnInt64(i)
case sqlite.TypeFloat:
result[colName] = stmt.ColumnFloat(i)
case sqlite.TypeText:
result[colName] = stmt.ColumnText(i)
case sqlite.TypeBlob:
blobSize := stmt.ColumnLen(i)
if blobSize > 0 {
buf := make([]byte, blobSize)
result[colName] = stmt.ColumnBytes(i, buf)
} else { } else {
state.PushBoolean(false) result[colName] = []byte{}
}
case sqlite.TypeNull:
result[colName] = nil
}
}
return nil
}
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
// Return result or nil if no rows
if result == nil {
state.PushNil()
} else {
if err := state.PushTable(result); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
} }
return 1 return 1
@ -486,7 +420,7 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil { if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
return err return err
} }
if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil { if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil {
return err return err
} }
return nil return nil