From d32801568137bcc2594baaa529ad5f5a8e934b47 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Sat, 10 May 2025 18:19:26 -0500 Subject: [PATCH] re-add connection tracking, but simpler this time --- runner/lua/sqlite.lua | 85 +++++++------------ runner/sqlite.go | 193 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 201 insertions(+), 77 deletions(-) diff --git a/runner/lua/sqlite.lua b/runner/lua/sqlite.lua index 1d82af9..ccfe458 100644 --- a/runner/lua/sqlite.lua +++ b/runner/lua/sqlite.lua @@ -1,5 +1,5 @@ -- Simplified SQLite wrapper --- Connection is now lightweight, we don't need to track IDs +-- Connection is now lightweight with persistent connection tracking -- Helper function to handle parameters local function handle_params(params, ...) @@ -28,35 +28,19 @@ local connection_mt = { error("connection:query: query must be a string", 2) end - -- Fast path for no parameters + -- Execute with proper connection tracking + local results, token if params == nil and select('#', ...) == 0 then - return __sqlite_query(self.db_name, query) + results, token = __sqlite_query(self.db_name, query, nil, self.conn_token) + elseif type(params) == "table" then + results, token = __sqlite_query(self.db_name, query, params, self.conn_token) + else + local args = {params, ...} + results, token = __sqlite_query(self.db_name, query, args, self.conn_token) end - -- Handle various parameter types efficiently - if type(params) == "table" then - -- If it's an array-like table with numeric keys - if params[1] ~= nil then - -- For positional parameters, we want to include the required prefix args - local args = {self.db_name, query} - -- Append all parameters - for i=1, #params do - args[i+2] = params[i] - end - return __sqlite_query(unpack(args)) - else - -- Named parameters - return __sqlite_query(self.db_name, query, params) - end - else - -- Variadic parameters, combine with first param - local args = {self.db_name, query, params} - local n = select('#', ...) - for i=1, n do - args[i+3] = select(i, ...) - end - return __sqlite_query(unpack(args)) - end + self.conn_token = token + return results end, -- Execute a statement and return affected rows @@ -65,35 +49,29 @@ local connection_mt = { error("connection:exec: query must be a string", 2) end - -- Fast path for no parameters + -- Execute with proper connection tracking + local affected, token if params == nil and select('#', ...) == 0 then - return __sqlite_exec(self.db_name, query) + affected, token = __sqlite_exec(self.db_name, query, nil, self.conn_token) + elseif type(params) == "table" then + affected, token = __sqlite_exec(self.db_name, query, params, self.conn_token) + else + local args = {params, ...} + affected, token = __sqlite_exec(self.db_name, query, args, self.conn_token) end - -- Handle various parameter types efficiently - if type(params) == "table" then - -- If it's an array-like table with numeric keys - if params[1] ~= nil then - -- For positional parameters, we want to include the required prefix args - local args = {self.db_name, query} - -- Append all parameters - for i=1, #params do - args[i+2] = params[i] - end - return __sqlite_exec(unpack(args)) - else - -- Named parameters - return __sqlite_exec(self.db_name, query, params) - end - else - -- Variadic parameters, combine with first param - local args = {self.db_name, query, params} - local n = select('#', ...) - for i=1, n do - args[i+3] = select(i, ...) - end - return __sqlite_exec(unpack(args)) + self.conn_token = token + return affected + end, + + -- Close the connection (release back to pool) + close = function(self) + if self.conn_token then + local success = __sqlite_close(self.conn_token) + self.conn_token = nil + return success end + return false end, -- Insert a row or multiple rows with a single query @@ -451,7 +429,8 @@ return function(db_name) end local conn = { - db_name = db_name + db_name = db_name, + conn_token = nil -- Will be populated on first query/exec } return setmetatable(conn, connection_mt) diff --git a/runner/sqlite.go b/runner/sqlite.go index 2bcd95d..e798db2 100644 --- a/runner/sqlite.go +++ b/runner/sqlite.go @@ -2,10 +2,13 @@ package runner import ( "context" + "crypto/rand" + "encoding/base64" "fmt" "path/filepath" "strings" "sync" + "time" sqlite "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" @@ -20,16 +23,66 @@ var ( dbPools = make(map[string]*sqlitex.Pool) poolsMu sync.RWMutex dataDir string + + // Connection tracking + activeConns = make(map[string]*TrackedConn) + activeConnMu sync.RWMutex + connTimeout = 5 * time.Minute ) +// TrackedConn holds a connection with usage tracking +type TrackedConn struct { + Conn *sqlite.Conn + Pool *sqlitex.Pool + DBName string + LastUsed time.Time +} + +// generateConnToken creates a unique token for connection tracking +func generateConnToken() string { + b := make([]byte, 8) + rand.Read(b) + return base64.URLEncoding.EncodeToString(b) +} + // InitSQLite initializes the SQLite subsystem func InitSQLite(dir string) { dataDir = dir logger.Server("SQLite initialized with data directory: %s", dir) + + // Start connection cleanup goroutine + go cleanupIdleConnections() +} + +// cleanupIdleConnections periodically checks for and removes idle connections +func cleanupIdleConnections() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for range ticker.C { + now := time.Now() + + activeConnMu.Lock() + for token, conn := range activeConns { + if conn.LastUsed.Add(connTimeout).Before(now) { + logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName) + conn.Pool.Put(conn.Conn) + delete(activeConns, token) + } + } + activeConnMu.Unlock() + } } // CleanupSQLite closes all database connections func CleanupSQLite() { + activeConnMu.Lock() + for token, conn := range activeConns { + conn.Pool.Put(conn.Conn) + delete(activeConns, token) + } + activeConnMu.Unlock() + poolsMu.Lock() defer poolsMu.Unlock() @@ -80,6 +133,63 @@ func getPool(dbName string) (*sqlitex.Pool, error) { return pool, nil } +// getConnection retrieves or creates a tracked connection +func getConnection(token, dbName string) (*TrackedConn, string, error) { + // If token is provided, try to get existing connection + if token != "" { + activeConnMu.RLock() + conn, exists := activeConns[token] + activeConnMu.RUnlock() + + if exists { + conn.LastUsed = time.Now() + return conn, token, nil + } + } + + // Token not provided or connection not found, create new + pool, err := getPool(dbName) + if err != nil { + return nil, "", err + } + + conn, err := pool.Take(context.Background()) + if err != nil { + return nil, "", err + } + + // Generate new token + newToken := generateConnToken() + + trackedConn := &TrackedConn{ + Conn: conn, + Pool: pool, + DBName: dbName, + LastUsed: time.Now(), + } + + activeConnMu.Lock() + activeConns[newToken] = trackedConn + activeConnMu.Unlock() + + return trackedConn, newToken, nil +} + +// releaseConnection releases a connection back to the pool +func releaseConnection(token string) bool { + activeConnMu.Lock() + defer activeConnMu.Unlock() + + conn, exists := activeConns[token] + if !exists { + return false + } + + conn.Pool.Put(conn.Conn) + delete(activeConns, token) + return true +} + // sqlQuery executes a SQL query and returns results func sqlQuery(state *luajit.State) int { // Get required parameters @@ -91,20 +201,20 @@ func sqlQuery(state *luajit.State) int { dbName := state.ToString(1) query := state.ToString(2) - // Get connection pool - pool, err := getPool(dbName) + // Get connection token (optional) + var connToken string + if state.GetTop() >= 4 && state.IsString(4) { + connToken = state.ToString(4) + } + + // Get connection + trackedConn, newToken, err := getConnection(connToken, dbName) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } - // Get a connection from the pool - conn, err := pool.Take(context.Background()) - if err != nil { - state.PushString(fmt.Sprintf("sqlite.query: failed to get connection: %s", err.Error())) - return -1 - } - defer pool.Put(conn) + conn := trackedConn.Conn // Create execution options var execOpts sqlitex.ExecOptions @@ -145,6 +255,9 @@ func sqlQuery(state *luajit.State) int { } else { // Positional parameters count := state.GetTop() - 2 + if state.IsString(4) { + count-- // Don't include connection token + } args := make([]any, count) for i := range count { idx := i + 3 @@ -213,7 +326,10 @@ func sqlQuery(state *luajit.State) int { state.SetTable(-3) } - return 1 + // Return connection token + state.PushString(newToken) + + return 2 } // sqlExec executes a SQL statement without returning results @@ -227,20 +343,20 @@ func sqlExec(state *luajit.State) int { dbName := state.ToString(1) query := state.ToString(2) - // Get connection pool - pool, err := getPool(dbName) + // Get connection token (optional) + var connToken string + if state.GetTop() >= 4 && state.IsString(4) { + connToken = state.ToString(4) + } + + // Get connection + trackedConn, newToken, err := getConnection(connToken, dbName) if err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) + state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } - // Get a connection from the pool - conn, err := pool.Take(context.Background()) - if err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: failed to get connection: %s", err.Error())) - return -1 - } - defer pool.Put(conn) + conn := trackedConn.Conn // Check if parameters are provided hasParams := state.GetTop() >= 3 && !state.IsNil(3) @@ -252,7 +368,8 @@ func sqlExec(state *luajit.State) int { return -1 } state.PushNumber(float64(conn.Changes())) - return 1 + state.PushString(newToken) + return 2 } // Fast path for simple queries with no parameters @@ -263,7 +380,8 @@ func sqlExec(state *luajit.State) int { return -1 } state.PushNumber(float64(conn.Changes())) - return 1 + state.PushString(newToken) + return 2 } // Create execution options for parameterized query @@ -303,6 +421,9 @@ func sqlExec(state *luajit.State) int { } else { // Positional parameters count := state.GetTop() - 2 + if state.IsString(4) { + count-- // Don't include connection token + } args := make([]any, count) for i := range count { idx := i + 3 @@ -333,8 +454,26 @@ func sqlExec(state *luajit.State) int { return -1 } - // Return affected rows + // Return affected rows and connection token state.PushNumber(float64(conn.Changes())) + state.PushString(newToken) + return 2 +} + +// sqlClose releases a connection back to the pool +func sqlClose(state *luajit.State) int { + if state.GetTop() < 1 || !state.IsString(1) { + state.PushString("sqlite.close: requires connection token") + return -1 + } + + token := state.ToString(1) + if releaseConnection(token) { + state.PushBoolean(true) + } else { + state.PushBoolean(false) + } + return 1 } @@ -343,5 +482,11 @@ func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { return err } - return state.RegisterGoFunction("__sqlite_exec", sqlExec) + if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil { + return err + } + if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil { + return err + } + return nil }