re-add connection tracking, but simpler this time

This commit is contained in:
Sky Johnson 2025-05-10 18:19:26 -05:00
parent 0581d72065
commit d328015681
2 changed files with 201 additions and 77 deletions

View File

@ -1,5 +1,5 @@
-- Simplified SQLite wrapper
-- Connection is now lightweight, we don't need to track IDs
-- Connection is now lightweight with persistent connection tracking
-- Helper function to handle parameters
local function handle_params(params, ...)
@ -28,35 +28,19 @@ local connection_mt = {
error("connection:query: query must be a string", 2)
end
-- Fast path for no parameters
-- Execute with proper connection tracking
local results, token
if params == nil and select('#', ...) == 0 then
return __sqlite_query(self.db_name, query)
results, token = __sqlite_query(self.db_name, query, nil, self.conn_token)
elseif type(params) == "table" then
results, token = __sqlite_query(self.db_name, query, params, self.conn_token)
else
local args = {params, ...}
results, token = __sqlite_query(self.db_name, query, args, self.conn_token)
end
-- Handle various parameter types efficiently
if type(params) == "table" then
-- If it's an array-like table with numeric keys
if params[1] ~= nil then
-- For positional parameters, we want to include the required prefix args
local args = {self.db_name, query}
-- Append all parameters
for i=1, #params do
args[i+2] = params[i]
end
return __sqlite_query(unpack(args))
else
-- Named parameters
return __sqlite_query(self.db_name, query, params)
end
else
-- Variadic parameters, combine with first param
local args = {self.db_name, query, params}
local n = select('#', ...)
for i=1, n do
args[i+3] = select(i, ...)
end
return __sqlite_query(unpack(args))
end
self.conn_token = token
return results
end,
-- Execute a statement and return affected rows
@ -65,35 +49,29 @@ local connection_mt = {
error("connection:exec: query must be a string", 2)
end
-- Fast path for no parameters
-- Execute with proper connection tracking
local affected, token
if params == nil and select('#', ...) == 0 then
return __sqlite_exec(self.db_name, query)
affected, token = __sqlite_exec(self.db_name, query, nil, self.conn_token)
elseif type(params) == "table" then
affected, token = __sqlite_exec(self.db_name, query, params, self.conn_token)
else
local args = {params, ...}
affected, token = __sqlite_exec(self.db_name, query, args, self.conn_token)
end
-- Handle various parameter types efficiently
if type(params) == "table" then
-- If it's an array-like table with numeric keys
if params[1] ~= nil then
-- For positional parameters, we want to include the required prefix args
local args = {self.db_name, query}
-- Append all parameters
for i=1, #params do
args[i+2] = params[i]
end
return __sqlite_exec(unpack(args))
else
-- Named parameters
return __sqlite_exec(self.db_name, query, params)
end
else
-- Variadic parameters, combine with first param
local args = {self.db_name, query, params}
local n = select('#', ...)
for i=1, n do
args[i+3] = select(i, ...)
end
return __sqlite_exec(unpack(args))
self.conn_token = token
return affected
end,
-- Close the connection (release back to pool)
close = function(self)
if self.conn_token then
local success = __sqlite_close(self.conn_token)
self.conn_token = nil
return success
end
return false
end,
-- Insert a row or multiple rows with a single query
@ -451,7 +429,8 @@ return function(db_name)
end
local conn = {
db_name = db_name
db_name = db_name,
conn_token = nil -- Will be populated on first query/exec
}
return setmetatable(conn, connection_mt)

View File

@ -2,10 +2,13 @@ package runner
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"path/filepath"
"strings"
"sync"
"time"
sqlite "zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
@ -20,16 +23,66 @@ var (
dbPools = make(map[string]*sqlitex.Pool)
poolsMu sync.RWMutex
dataDir string
// Connection tracking
activeConns = make(map[string]*TrackedConn)
activeConnMu sync.RWMutex
connTimeout = 5 * time.Minute
)
// TrackedConn holds a connection with usage tracking
type TrackedConn struct {
Conn *sqlite.Conn
Pool *sqlitex.Pool
DBName string
LastUsed time.Time
}
// generateConnToken creates a unique token for connection tracking
func generateConnToken() string {
b := make([]byte, 8)
rand.Read(b)
return base64.URLEncoding.EncodeToString(b)
}
// InitSQLite initializes the SQLite subsystem
func InitSQLite(dir string) {
dataDir = dir
logger.Server("SQLite initialized with data directory: %s", dir)
// Start connection cleanup goroutine
go cleanupIdleConnections()
}
// cleanupIdleConnections periodically checks for and removes idle connections
func cleanupIdleConnections() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
activeConnMu.Lock()
for token, conn := range activeConns {
if conn.LastUsed.Add(connTimeout).Before(now) {
logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName)
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
}
}
activeConnMu.Unlock()
}
}
// CleanupSQLite closes all database connections
func CleanupSQLite() {
activeConnMu.Lock()
for token, conn := range activeConns {
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
}
activeConnMu.Unlock()
poolsMu.Lock()
defer poolsMu.Unlock()
@ -80,6 +133,63 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return pool, nil
}
// getConnection retrieves or creates a tracked connection
func getConnection(token, dbName string) (*TrackedConn, string, error) {
// If token is provided, try to get existing connection
if token != "" {
activeConnMu.RLock()
conn, exists := activeConns[token]
activeConnMu.RUnlock()
if exists {
conn.LastUsed = time.Now()
return conn, token, nil
}
}
// Token not provided or connection not found, create new
pool, err := getPool(dbName)
if err != nil {
return nil, "", err
}
conn, err := pool.Take(context.Background())
if err != nil {
return nil, "", err
}
// Generate new token
newToken := generateConnToken()
trackedConn := &TrackedConn{
Conn: conn,
Pool: pool,
DBName: dbName,
LastUsed: time.Now(),
}
activeConnMu.Lock()
activeConns[newToken] = trackedConn
activeConnMu.Unlock()
return trackedConn, newToken, nil
}
// releaseConnection releases a connection back to the pool
func releaseConnection(token string) bool {
activeConnMu.Lock()
defer activeConnMu.Unlock()
conn, exists := activeConns[token]
if !exists {
return false
}
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
return true
}
// sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int {
// Get required parameters
@ -91,20 +201,20 @@ func sqlQuery(state *luajit.State) int {
dbName := state.ToString(1)
query := state.ToString(2)
// Get connection pool
pool, err := getPool(dbName)
// Get connection token (optional)
var connToken string
if state.GetTop() >= 4 && state.IsString(4) {
connToken = state.ToString(4)
}
// Get connection
trackedConn, newToken, err := getConnection(connToken, dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
// Get a connection from the pool
conn, err := pool.Take(context.Background())
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: failed to get connection: %s", err.Error()))
return -1
}
defer pool.Put(conn)
conn := trackedConn.Conn
// Create execution options
var execOpts sqlitex.ExecOptions
@ -145,6 +255,9 @@ func sqlQuery(state *luajit.State) int {
} else {
// Positional parameters
count := state.GetTop() - 2
if state.IsString(4) {
count-- // Don't include connection token
}
args := make([]any, count)
for i := range count {
idx := i + 3
@ -213,7 +326,10 @@ func sqlQuery(state *luajit.State) int {
state.SetTable(-3)
}
return 1
// Return connection token
state.PushString(newToken)
return 2
}
// sqlExec executes a SQL statement without returning results
@ -227,20 +343,20 @@ func sqlExec(state *luajit.State) int {
dbName := state.ToString(1)
query := state.ToString(2)
// Get connection pool
pool, err := getPool(dbName)
// Get connection token (optional)
var connToken string
if state.GetTop() >= 4 && state.IsString(4) {
connToken = state.ToString(4)
}
// Get connection
trackedConn, newToken, err := getConnection(connToken, dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
// Get a connection from the pool
conn, err := pool.Take(context.Background())
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: failed to get connection: %s", err.Error()))
return -1
}
defer pool.Put(conn)
conn := trackedConn.Conn
// Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
@ -252,7 +368,8 @@ func sqlExec(state *luajit.State) int {
return -1
}
state.PushNumber(float64(conn.Changes()))
return 1
state.PushString(newToken)
return 2
}
// Fast path for simple queries with no parameters
@ -263,7 +380,8 @@ func sqlExec(state *luajit.State) int {
return -1
}
state.PushNumber(float64(conn.Changes()))
return 1
state.PushString(newToken)
return 2
}
// Create execution options for parameterized query
@ -303,6 +421,9 @@ func sqlExec(state *luajit.State) int {
} else {
// Positional parameters
count := state.GetTop() - 2
if state.IsString(4) {
count-- // Don't include connection token
}
args := make([]any, count)
for i := range count {
idx := i + 3
@ -333,8 +454,26 @@ func sqlExec(state *luajit.State) int {
return -1
}
// Return affected rows
// Return affected rows and connection token
state.PushNumber(float64(conn.Changes()))
state.PushString(newToken)
return 2
}
// sqlClose releases a connection back to the pool
func sqlClose(state *luajit.State) int {
if state.GetTop() < 1 || !state.IsString(1) {
state.PushString("sqlite.close: requires connection token")
return -1
}
token := state.ToString(1)
if releaseConnection(token) {
state.PushBoolean(true)
} else {
state.PushBoolean(false)
}
return 1
}
@ -343,5 +482,11 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
return err
}
return state.RegisterGoFunction("__sqlite_exec", sqlExec)
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
return err
}
if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil {
return err
}
return nil
}