re-add connection tracking, but simpler this time

This commit is contained in:
Sky Johnson 2025-05-10 18:19:26 -05:00
parent 0581d72065
commit d328015681
2 changed files with 201 additions and 77 deletions

View File

@ -1,5 +1,5 @@
-- Simplified SQLite wrapper -- Simplified SQLite wrapper
-- Connection is now lightweight, we don't need to track IDs -- Connection is now lightweight with persistent connection tracking
-- Helper function to handle parameters -- Helper function to handle parameters
local function handle_params(params, ...) local function handle_params(params, ...)
@ -28,35 +28,19 @@ local connection_mt = {
error("connection:query: query must be a string", 2) error("connection:query: query must be a string", 2)
end end
-- Fast path for no parameters -- Execute with proper connection tracking
local results, token
if params == nil and select('#', ...) == 0 then if params == nil and select('#', ...) == 0 then
return __sqlite_query(self.db_name, query) results, token = __sqlite_query(self.db_name, query, nil, self.conn_token)
elseif type(params) == "table" then
results, token = __sqlite_query(self.db_name, query, params, self.conn_token)
else
local args = {params, ...}
results, token = __sqlite_query(self.db_name, query, args, self.conn_token)
end end
-- Handle various parameter types efficiently self.conn_token = token
if type(params) == "table" then return results
-- If it's an array-like table with numeric keys
if params[1] ~= nil then
-- For positional parameters, we want to include the required prefix args
local args = {self.db_name, query}
-- Append all parameters
for i=1, #params do
args[i+2] = params[i]
end
return __sqlite_query(unpack(args))
else
-- Named parameters
return __sqlite_query(self.db_name, query, params)
end
else
-- Variadic parameters, combine with first param
local args = {self.db_name, query, params}
local n = select('#', ...)
for i=1, n do
args[i+3] = select(i, ...)
end
return __sqlite_query(unpack(args))
end
end, end,
-- Execute a statement and return affected rows -- Execute a statement and return affected rows
@ -65,35 +49,29 @@ local connection_mt = {
error("connection:exec: query must be a string", 2) error("connection:exec: query must be a string", 2)
end end
-- Fast path for no parameters -- Execute with proper connection tracking
local affected, token
if params == nil and select('#', ...) == 0 then if params == nil and select('#', ...) == 0 then
return __sqlite_exec(self.db_name, query) affected, token = __sqlite_exec(self.db_name, query, nil, self.conn_token)
elseif type(params) == "table" then
affected, token = __sqlite_exec(self.db_name, query, params, self.conn_token)
else
local args = {params, ...}
affected, token = __sqlite_exec(self.db_name, query, args, self.conn_token)
end end
-- Handle various parameter types efficiently self.conn_token = token
if type(params) == "table" then return affected
-- If it's an array-like table with numeric keys end,
if params[1] ~= nil then
-- For positional parameters, we want to include the required prefix args -- Close the connection (release back to pool)
local args = {self.db_name, query} close = function(self)
-- Append all parameters if self.conn_token then
for i=1, #params do local success = __sqlite_close(self.conn_token)
args[i+2] = params[i] self.conn_token = nil
end return success
return __sqlite_exec(unpack(args))
else
-- Named parameters
return __sqlite_exec(self.db_name, query, params)
end
else
-- Variadic parameters, combine with first param
local args = {self.db_name, query, params}
local n = select('#', ...)
for i=1, n do
args[i+3] = select(i, ...)
end
return __sqlite_exec(unpack(args))
end end
return false
end, end,
-- Insert a row or multiple rows with a single query -- Insert a row or multiple rows with a single query
@ -451,7 +429,8 @@ return function(db_name)
end end
local conn = { local conn = {
db_name = db_name db_name = db_name,
conn_token = nil -- Will be populated on first query/exec
} }
return setmetatable(conn, connection_mt) return setmetatable(conn, connection_mt)

View File

@ -2,10 +2,13 @@ package runner
import ( import (
"context" "context"
"crypto/rand"
"encoding/base64"
"fmt" "fmt"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"time"
sqlite "zombiezen.com/go/sqlite" sqlite "zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex" "zombiezen.com/go/sqlite/sqlitex"
@ -20,16 +23,66 @@ var (
dbPools = make(map[string]*sqlitex.Pool) dbPools = make(map[string]*sqlitex.Pool)
poolsMu sync.RWMutex poolsMu sync.RWMutex
dataDir string dataDir string
// Connection tracking
activeConns = make(map[string]*TrackedConn)
activeConnMu sync.RWMutex
connTimeout = 5 * time.Minute
) )
// TrackedConn holds a connection with usage tracking
type TrackedConn struct {
Conn *sqlite.Conn
Pool *sqlitex.Pool
DBName string
LastUsed time.Time
}
// generateConnToken creates a unique token for connection tracking
func generateConnToken() string {
b := make([]byte, 8)
rand.Read(b)
return base64.URLEncoding.EncodeToString(b)
}
// InitSQLite initializes the SQLite subsystem // InitSQLite initializes the SQLite subsystem
func InitSQLite(dir string) { func InitSQLite(dir string) {
dataDir = dir dataDir = dir
logger.Server("SQLite initialized with data directory: %s", dir) logger.Server("SQLite initialized with data directory: %s", dir)
// Start connection cleanup goroutine
go cleanupIdleConnections()
}
// cleanupIdleConnections periodically checks for and removes idle connections
func cleanupIdleConnections() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
activeConnMu.Lock()
for token, conn := range activeConns {
if conn.LastUsed.Add(connTimeout).Before(now) {
logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName)
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
}
}
activeConnMu.Unlock()
}
} }
// CleanupSQLite closes all database connections // CleanupSQLite closes all database connections
func CleanupSQLite() { func CleanupSQLite() {
activeConnMu.Lock()
for token, conn := range activeConns {
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
}
activeConnMu.Unlock()
poolsMu.Lock() poolsMu.Lock()
defer poolsMu.Unlock() defer poolsMu.Unlock()
@ -80,6 +133,63 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return pool, nil return pool, nil
} }
// getConnection retrieves or creates a tracked connection
func getConnection(token, dbName string) (*TrackedConn, string, error) {
// If token is provided, try to get existing connection
if token != "" {
activeConnMu.RLock()
conn, exists := activeConns[token]
activeConnMu.RUnlock()
if exists {
conn.LastUsed = time.Now()
return conn, token, nil
}
}
// Token not provided or connection not found, create new
pool, err := getPool(dbName)
if err != nil {
return nil, "", err
}
conn, err := pool.Take(context.Background())
if err != nil {
return nil, "", err
}
// Generate new token
newToken := generateConnToken()
trackedConn := &TrackedConn{
Conn: conn,
Pool: pool,
DBName: dbName,
LastUsed: time.Now(),
}
activeConnMu.Lock()
activeConns[newToken] = trackedConn
activeConnMu.Unlock()
return trackedConn, newToken, nil
}
// releaseConnection releases a connection back to the pool
func releaseConnection(token string) bool {
activeConnMu.Lock()
defer activeConnMu.Unlock()
conn, exists := activeConns[token]
if !exists {
return false
}
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
return true
}
// sqlQuery executes a SQL query and returns results // sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int { func sqlQuery(state *luajit.State) int {
// Get required parameters // Get required parameters
@ -91,20 +201,20 @@ func sqlQuery(state *luajit.State) int {
dbName := state.ToString(1) dbName := state.ToString(1)
query := state.ToString(2) query := state.ToString(2)
// Get connection pool // Get connection token (optional)
pool, err := getPool(dbName) var connToken string
if state.GetTop() >= 4 && state.IsString(4) {
connToken = state.ToString(4)
}
// Get connection
trackedConn, newToken, err := getConnection(connToken, dbName)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1 return -1
} }
// Get a connection from the pool conn := trackedConn.Conn
conn, err := pool.Take(context.Background())
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: failed to get connection: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Create execution options // Create execution options
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
@ -145,6 +255,9 @@ func sqlQuery(state *luajit.State) int {
} else { } else {
// Positional parameters // Positional parameters
count := state.GetTop() - 2 count := state.GetTop() - 2
if state.IsString(4) {
count-- // Don't include connection token
}
args := make([]any, count) args := make([]any, count)
for i := range count { for i := range count {
idx := i + 3 idx := i + 3
@ -213,7 +326,10 @@ func sqlQuery(state *luajit.State) int {
state.SetTable(-3) state.SetTable(-3)
} }
return 1 // Return connection token
state.PushString(newToken)
return 2
} }
// sqlExec executes a SQL statement without returning results // sqlExec executes a SQL statement without returning results
@ -227,20 +343,20 @@ func sqlExec(state *luajit.State) int {
dbName := state.ToString(1) dbName := state.ToString(1)
query := state.ToString(2) query := state.ToString(2)
// Get connection pool // Get connection token (optional)
pool, err := getPool(dbName) var connToken string
if state.GetTop() >= 4 && state.IsString(4) {
connToken = state.ToString(4)
}
// Get connection
trackedConn, newToken, err := getConnection(connToken, dbName)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1 return -1
} }
// Get a connection from the pool conn := trackedConn.Conn
conn, err := pool.Take(context.Background())
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: failed to get connection: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Check if parameters are provided // Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3) hasParams := state.GetTop() >= 3 && !state.IsNil(3)
@ -252,7 +368,8 @@ func sqlExec(state *luajit.State) int {
return -1 return -1
} }
state.PushNumber(float64(conn.Changes())) state.PushNumber(float64(conn.Changes()))
return 1 state.PushString(newToken)
return 2
} }
// Fast path for simple queries with no parameters // Fast path for simple queries with no parameters
@ -263,7 +380,8 @@ func sqlExec(state *luajit.State) int {
return -1 return -1
} }
state.PushNumber(float64(conn.Changes())) state.PushNumber(float64(conn.Changes()))
return 1 state.PushString(newToken)
return 2
} }
// Create execution options for parameterized query // Create execution options for parameterized query
@ -303,6 +421,9 @@ func sqlExec(state *luajit.State) int {
} else { } else {
// Positional parameters // Positional parameters
count := state.GetTop() - 2 count := state.GetTop() - 2
if state.IsString(4) {
count-- // Don't include connection token
}
args := make([]any, count) args := make([]any, count)
for i := range count { for i := range count {
idx := i + 3 idx := i + 3
@ -333,8 +454,26 @@ func sqlExec(state *luajit.State) int {
return -1 return -1
} }
// Return affected rows // Return affected rows and connection token
state.PushNumber(float64(conn.Changes())) state.PushNumber(float64(conn.Changes()))
state.PushString(newToken)
return 2
}
// sqlClose releases a connection back to the pool
func sqlClose(state *luajit.State) int {
if state.GetTop() < 1 || !state.IsString(1) {
state.PushString("sqlite.close: requires connection token")
return -1
}
token := state.ToString(1)
if releaseConnection(token) {
state.PushBoolean(true)
} else {
state.PushBoolean(false)
}
return 1 return 1
} }
@ -343,5 +482,11 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
return err return err
} }
return state.RegisterGoFunction("__sqlite_exec", sqlExec) if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
return err
}
if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil {
return err
}
return nil
} }