re-add connection tracking, but simpler this time
This commit is contained in:
parent
0581d72065
commit
d328015681
@ -1,5 +1,5 @@
|
||||
-- Simplified SQLite wrapper
|
||||
-- Connection is now lightweight, we don't need to track IDs
|
||||
-- Connection is now lightweight with persistent connection tracking
|
||||
|
||||
-- Helper function to handle parameters
|
||||
local function handle_params(params, ...)
|
||||
@ -28,35 +28,19 @@ local connection_mt = {
|
||||
error("connection:query: query must be a string", 2)
|
||||
end
|
||||
|
||||
-- Fast path for no parameters
|
||||
-- Execute with proper connection tracking
|
||||
local results, token
|
||||
if params == nil and select('#', ...) == 0 then
|
||||
return __sqlite_query(self.db_name, query)
|
||||
results, token = __sqlite_query(self.db_name, query, nil, self.conn_token)
|
||||
elseif type(params) == "table" then
|
||||
results, token = __sqlite_query(self.db_name, query, params, self.conn_token)
|
||||
else
|
||||
local args = {params, ...}
|
||||
results, token = __sqlite_query(self.db_name, query, args, self.conn_token)
|
||||
end
|
||||
|
||||
-- Handle various parameter types efficiently
|
||||
if type(params) == "table" then
|
||||
-- If it's an array-like table with numeric keys
|
||||
if params[1] ~= nil then
|
||||
-- For positional parameters, we want to include the required prefix args
|
||||
local args = {self.db_name, query}
|
||||
-- Append all parameters
|
||||
for i=1, #params do
|
||||
args[i+2] = params[i]
|
||||
end
|
||||
return __sqlite_query(unpack(args))
|
||||
else
|
||||
-- Named parameters
|
||||
return __sqlite_query(self.db_name, query, params)
|
||||
end
|
||||
else
|
||||
-- Variadic parameters, combine with first param
|
||||
local args = {self.db_name, query, params}
|
||||
local n = select('#', ...)
|
||||
for i=1, n do
|
||||
args[i+3] = select(i, ...)
|
||||
end
|
||||
return __sqlite_query(unpack(args))
|
||||
end
|
||||
self.conn_token = token
|
||||
return results
|
||||
end,
|
||||
|
||||
-- Execute a statement and return affected rows
|
||||
@ -65,35 +49,29 @@ local connection_mt = {
|
||||
error("connection:exec: query must be a string", 2)
|
||||
end
|
||||
|
||||
-- Fast path for no parameters
|
||||
-- Execute with proper connection tracking
|
||||
local affected, token
|
||||
if params == nil and select('#', ...) == 0 then
|
||||
return __sqlite_exec(self.db_name, query)
|
||||
affected, token = __sqlite_exec(self.db_name, query, nil, self.conn_token)
|
||||
elseif type(params) == "table" then
|
||||
affected, token = __sqlite_exec(self.db_name, query, params, self.conn_token)
|
||||
else
|
||||
local args = {params, ...}
|
||||
affected, token = __sqlite_exec(self.db_name, query, args, self.conn_token)
|
||||
end
|
||||
|
||||
-- Handle various parameter types efficiently
|
||||
if type(params) == "table" then
|
||||
-- If it's an array-like table with numeric keys
|
||||
if params[1] ~= nil then
|
||||
-- For positional parameters, we want to include the required prefix args
|
||||
local args = {self.db_name, query}
|
||||
-- Append all parameters
|
||||
for i=1, #params do
|
||||
args[i+2] = params[i]
|
||||
end
|
||||
return __sqlite_exec(unpack(args))
|
||||
else
|
||||
-- Named parameters
|
||||
return __sqlite_exec(self.db_name, query, params)
|
||||
end
|
||||
else
|
||||
-- Variadic parameters, combine with first param
|
||||
local args = {self.db_name, query, params}
|
||||
local n = select('#', ...)
|
||||
for i=1, n do
|
||||
args[i+3] = select(i, ...)
|
||||
end
|
||||
return __sqlite_exec(unpack(args))
|
||||
self.conn_token = token
|
||||
return affected
|
||||
end,
|
||||
|
||||
-- Close the connection (release back to pool)
|
||||
close = function(self)
|
||||
if self.conn_token then
|
||||
local success = __sqlite_close(self.conn_token)
|
||||
self.conn_token = nil
|
||||
return success
|
||||
end
|
||||
return false
|
||||
end,
|
||||
|
||||
-- Insert a row or multiple rows with a single query
|
||||
@ -451,7 +429,8 @@ return function(db_name)
|
||||
end
|
||||
|
||||
local conn = {
|
||||
db_name = db_name
|
||||
db_name = db_name,
|
||||
conn_token = nil -- Will be populated on first query/exec
|
||||
}
|
||||
|
||||
return setmetatable(conn, connection_mt)
|
||||
|
193
runner/sqlite.go
193
runner/sqlite.go
@ -2,10 +2,13 @@ package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sqlite "zombiezen.com/go/sqlite"
|
||||
"zombiezen.com/go/sqlite/sqlitex"
|
||||
@ -20,16 +23,66 @@ var (
|
||||
dbPools = make(map[string]*sqlitex.Pool)
|
||||
poolsMu sync.RWMutex
|
||||
dataDir string
|
||||
|
||||
// Connection tracking
|
||||
activeConns = make(map[string]*TrackedConn)
|
||||
activeConnMu sync.RWMutex
|
||||
connTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
// TrackedConn holds a connection with usage tracking
|
||||
type TrackedConn struct {
|
||||
Conn *sqlite.Conn
|
||||
Pool *sqlitex.Pool
|
||||
DBName string
|
||||
LastUsed time.Time
|
||||
}
|
||||
|
||||
// generateConnToken creates a unique token for connection tracking
|
||||
func generateConnToken() string {
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// InitSQLite initializes the SQLite subsystem
|
||||
func InitSQLite(dir string) {
|
||||
dataDir = dir
|
||||
logger.Server("SQLite initialized with data directory: %s", dir)
|
||||
|
||||
// Start connection cleanup goroutine
|
||||
go cleanupIdleConnections()
|
||||
}
|
||||
|
||||
// cleanupIdleConnections periodically checks for and removes idle connections
|
||||
func cleanupIdleConnections() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
|
||||
activeConnMu.Lock()
|
||||
for token, conn := range activeConns {
|
||||
if conn.LastUsed.Add(connTimeout).Before(now) {
|
||||
logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName)
|
||||
conn.Pool.Put(conn.Conn)
|
||||
delete(activeConns, token)
|
||||
}
|
||||
}
|
||||
activeConnMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupSQLite closes all database connections
|
||||
func CleanupSQLite() {
|
||||
activeConnMu.Lock()
|
||||
for token, conn := range activeConns {
|
||||
conn.Pool.Put(conn.Conn)
|
||||
delete(activeConns, token)
|
||||
}
|
||||
activeConnMu.Unlock()
|
||||
|
||||
poolsMu.Lock()
|
||||
defer poolsMu.Unlock()
|
||||
|
||||
@ -80,6 +133,63 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// getConnection retrieves or creates a tracked connection
|
||||
func getConnection(token, dbName string) (*TrackedConn, string, error) {
|
||||
// If token is provided, try to get existing connection
|
||||
if token != "" {
|
||||
activeConnMu.RLock()
|
||||
conn, exists := activeConns[token]
|
||||
activeConnMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
conn.LastUsed = time.Now()
|
||||
return conn, token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Token not provided or connection not found, create new
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
conn, err := pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Generate new token
|
||||
newToken := generateConnToken()
|
||||
|
||||
trackedConn := &TrackedConn{
|
||||
Conn: conn,
|
||||
Pool: pool,
|
||||
DBName: dbName,
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
|
||||
activeConnMu.Lock()
|
||||
activeConns[newToken] = trackedConn
|
||||
activeConnMu.Unlock()
|
||||
|
||||
return trackedConn, newToken, nil
|
||||
}
|
||||
|
||||
// releaseConnection releases a connection back to the pool
|
||||
func releaseConnection(token string) bool {
|
||||
activeConnMu.Lock()
|
||||
defer activeConnMu.Unlock()
|
||||
|
||||
conn, exists := activeConns[token]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
conn.Pool.Put(conn.Conn)
|
||||
delete(activeConns, token)
|
||||
return true
|
||||
}
|
||||
|
||||
// sqlQuery executes a SQL query and returns results
|
||||
func sqlQuery(state *luajit.State) int {
|
||||
// Get required parameters
|
||||
@ -91,20 +201,20 @@ func sqlQuery(state *luajit.State) int {
|
||||
dbName := state.ToString(1)
|
||||
query := state.ToString(2)
|
||||
|
||||
// Get connection pool
|
||||
pool, err := getPool(dbName)
|
||||
// Get connection token (optional)
|
||||
var connToken string
|
||||
if state.GetTop() >= 4 && state.IsString(4) {
|
||||
connToken = state.ToString(4)
|
||||
}
|
||||
|
||||
// Get connection
|
||||
trackedConn, newToken, err := getConnection(connToken, dbName)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Get a connection from the pool
|
||||
conn, err := pool.Take(context.Background())
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: failed to get connection: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
conn := trackedConn.Conn
|
||||
|
||||
// Create execution options
|
||||
var execOpts sqlitex.ExecOptions
|
||||
@ -145,6 +255,9 @@ func sqlQuery(state *luajit.State) int {
|
||||
} else {
|
||||
// Positional parameters
|
||||
count := state.GetTop() - 2
|
||||
if state.IsString(4) {
|
||||
count-- // Don't include connection token
|
||||
}
|
||||
args := make([]any, count)
|
||||
for i := range count {
|
||||
idx := i + 3
|
||||
@ -213,7 +326,10 @@ func sqlQuery(state *luajit.State) int {
|
||||
state.SetTable(-3)
|
||||
}
|
||||
|
||||
return 1
|
||||
// Return connection token
|
||||
state.PushString(newToken)
|
||||
|
||||
return 2
|
||||
}
|
||||
|
||||
// sqlExec executes a SQL statement without returning results
|
||||
@ -227,20 +343,20 @@ func sqlExec(state *luajit.State) int {
|
||||
dbName := state.ToString(1)
|
||||
query := state.ToString(2)
|
||||
|
||||
// Get connection pool
|
||||
pool, err := getPool(dbName)
|
||||
// Get connection token (optional)
|
||||
var connToken string
|
||||
if state.GetTop() >= 4 && state.IsString(4) {
|
||||
connToken = state.ToString(4)
|
||||
}
|
||||
|
||||
// Get connection
|
||||
trackedConn, newToken, err := getConnection(connToken, dbName)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Get a connection from the pool
|
||||
conn, err := pool.Take(context.Background())
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: failed to get connection: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
conn := trackedConn.Conn
|
||||
|
||||
// Check if parameters are provided
|
||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||
@ -252,7 +368,8 @@ func sqlExec(state *luajit.State) int {
|
||||
return -1
|
||||
}
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
state.PushString(newToken)
|
||||
return 2
|
||||
}
|
||||
|
||||
// Fast path for simple queries with no parameters
|
||||
@ -263,7 +380,8 @@ func sqlExec(state *luajit.State) int {
|
||||
return -1
|
||||
}
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
state.PushString(newToken)
|
||||
return 2
|
||||
}
|
||||
|
||||
// Create execution options for parameterized query
|
||||
@ -303,6 +421,9 @@ func sqlExec(state *luajit.State) int {
|
||||
} else {
|
||||
// Positional parameters
|
||||
count := state.GetTop() - 2
|
||||
if state.IsString(4) {
|
||||
count-- // Don't include connection token
|
||||
}
|
||||
args := make([]any, count)
|
||||
for i := range count {
|
||||
idx := i + 3
|
||||
@ -333,8 +454,26 @@ func sqlExec(state *luajit.State) int {
|
||||
return -1
|
||||
}
|
||||
|
||||
// Return affected rows
|
||||
// Return affected rows and connection token
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
state.PushString(newToken)
|
||||
return 2
|
||||
}
|
||||
|
||||
// sqlClose releases a connection back to the pool
|
||||
func sqlClose(state *luajit.State) int {
|
||||
if state.GetTop() < 1 || !state.IsString(1) {
|
||||
state.PushString("sqlite.close: requires connection token")
|
||||
return -1
|
||||
}
|
||||
|
||||
token := state.ToString(1)
|
||||
if releaseConnection(token) {
|
||||
state.PushBoolean(true)
|
||||
} else {
|
||||
state.PushBoolean(false)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
@ -343,5 +482,11 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
return state.RegisterGoFunction("__sqlite_exec", sqlExec)
|
||||
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user