re-add connection tracking, but simpler this time
This commit is contained in:
parent
0581d72065
commit
d328015681
@ -1,5 +1,5 @@
|
|||||||
-- Simplified SQLite wrapper
|
-- Simplified SQLite wrapper
|
||||||
-- Connection is now lightweight, we don't need to track IDs
|
-- Connection is now lightweight with persistent connection tracking
|
||||||
|
|
||||||
-- Helper function to handle parameters
|
-- Helper function to handle parameters
|
||||||
local function handle_params(params, ...)
|
local function handle_params(params, ...)
|
||||||
@ -28,35 +28,19 @@ local connection_mt = {
|
|||||||
error("connection:query: query must be a string", 2)
|
error("connection:query: query must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Fast path for no parameters
|
-- Execute with proper connection tracking
|
||||||
|
local results, token
|
||||||
if params == nil and select('#', ...) == 0 then
|
if params == nil and select('#', ...) == 0 then
|
||||||
return __sqlite_query(self.db_name, query)
|
results, token = __sqlite_query(self.db_name, query, nil, self.conn_token)
|
||||||
|
elseif type(params) == "table" then
|
||||||
|
results, token = __sqlite_query(self.db_name, query, params, self.conn_token)
|
||||||
|
else
|
||||||
|
local args = {params, ...}
|
||||||
|
results, token = __sqlite_query(self.db_name, query, args, self.conn_token)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Handle various parameter types efficiently
|
self.conn_token = token
|
||||||
if type(params) == "table" then
|
return results
|
||||||
-- If it's an array-like table with numeric keys
|
|
||||||
if params[1] ~= nil then
|
|
||||||
-- For positional parameters, we want to include the required prefix args
|
|
||||||
local args = {self.db_name, query}
|
|
||||||
-- Append all parameters
|
|
||||||
for i=1, #params do
|
|
||||||
args[i+2] = params[i]
|
|
||||||
end
|
|
||||||
return __sqlite_query(unpack(args))
|
|
||||||
else
|
|
||||||
-- Named parameters
|
|
||||||
return __sqlite_query(self.db_name, query, params)
|
|
||||||
end
|
|
||||||
else
|
|
||||||
-- Variadic parameters, combine with first param
|
|
||||||
local args = {self.db_name, query, params}
|
|
||||||
local n = select('#', ...)
|
|
||||||
for i=1, n do
|
|
||||||
args[i+3] = select(i, ...)
|
|
||||||
end
|
|
||||||
return __sqlite_query(unpack(args))
|
|
||||||
end
|
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Execute a statement and return affected rows
|
-- Execute a statement and return affected rows
|
||||||
@ -65,35 +49,29 @@ local connection_mt = {
|
|||||||
error("connection:exec: query must be a string", 2)
|
error("connection:exec: query must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Fast path for no parameters
|
-- Execute with proper connection tracking
|
||||||
|
local affected, token
|
||||||
if params == nil and select('#', ...) == 0 then
|
if params == nil and select('#', ...) == 0 then
|
||||||
return __sqlite_exec(self.db_name, query)
|
affected, token = __sqlite_exec(self.db_name, query, nil, self.conn_token)
|
||||||
|
elseif type(params) == "table" then
|
||||||
|
affected, token = __sqlite_exec(self.db_name, query, params, self.conn_token)
|
||||||
|
else
|
||||||
|
local args = {params, ...}
|
||||||
|
affected, token = __sqlite_exec(self.db_name, query, args, self.conn_token)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Handle various parameter types efficiently
|
self.conn_token = token
|
||||||
if type(params) == "table" then
|
return affected
|
||||||
-- If it's an array-like table with numeric keys
|
end,
|
||||||
if params[1] ~= nil then
|
|
||||||
-- For positional parameters, we want to include the required prefix args
|
-- Close the connection (release back to pool)
|
||||||
local args = {self.db_name, query}
|
close = function(self)
|
||||||
-- Append all parameters
|
if self.conn_token then
|
||||||
for i=1, #params do
|
local success = __sqlite_close(self.conn_token)
|
||||||
args[i+2] = params[i]
|
self.conn_token = nil
|
||||||
end
|
return success
|
||||||
return __sqlite_exec(unpack(args))
|
|
||||||
else
|
|
||||||
-- Named parameters
|
|
||||||
return __sqlite_exec(self.db_name, query, params)
|
|
||||||
end
|
|
||||||
else
|
|
||||||
-- Variadic parameters, combine with first param
|
|
||||||
local args = {self.db_name, query, params}
|
|
||||||
local n = select('#', ...)
|
|
||||||
for i=1, n do
|
|
||||||
args[i+3] = select(i, ...)
|
|
||||||
end
|
|
||||||
return __sqlite_exec(unpack(args))
|
|
||||||
end
|
end
|
||||||
|
return false
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Insert a row or multiple rows with a single query
|
-- Insert a row or multiple rows with a single query
|
||||||
@ -451,7 +429,8 @@ return function(db_name)
|
|||||||
end
|
end
|
||||||
|
|
||||||
local conn = {
|
local conn = {
|
||||||
db_name = db_name
|
db_name = db_name,
|
||||||
|
conn_token = nil -- Will be populated on first query/exec
|
||||||
}
|
}
|
||||||
|
|
||||||
return setmetatable(conn, connection_mt)
|
return setmetatable(conn, connection_mt)
|
||||||
|
193
runner/sqlite.go
193
runner/sqlite.go
@ -2,10 +2,13 @@ package runner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
sqlite "zombiezen.com/go/sqlite"
|
sqlite "zombiezen.com/go/sqlite"
|
||||||
"zombiezen.com/go/sqlite/sqlitex"
|
"zombiezen.com/go/sqlite/sqlitex"
|
||||||
@ -20,16 +23,66 @@ var (
|
|||||||
dbPools = make(map[string]*sqlitex.Pool)
|
dbPools = make(map[string]*sqlitex.Pool)
|
||||||
poolsMu sync.RWMutex
|
poolsMu sync.RWMutex
|
||||||
dataDir string
|
dataDir string
|
||||||
|
|
||||||
|
// Connection tracking
|
||||||
|
activeConns = make(map[string]*TrackedConn)
|
||||||
|
activeConnMu sync.RWMutex
|
||||||
|
connTimeout = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TrackedConn holds a connection with usage tracking
|
||||||
|
type TrackedConn struct {
|
||||||
|
Conn *sqlite.Conn
|
||||||
|
Pool *sqlitex.Pool
|
||||||
|
DBName string
|
||||||
|
LastUsed time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateConnToken creates a unique token for connection tracking
|
||||||
|
func generateConnToken() string {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
rand.Read(b)
|
||||||
|
return base64.URLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
// InitSQLite initializes the SQLite subsystem
|
// InitSQLite initializes the SQLite subsystem
|
||||||
func InitSQLite(dir string) {
|
func InitSQLite(dir string) {
|
||||||
dataDir = dir
|
dataDir = dir
|
||||||
logger.Server("SQLite initialized with data directory: %s", dir)
|
logger.Server("SQLite initialized with data directory: %s", dir)
|
||||||
|
|
||||||
|
// Start connection cleanup goroutine
|
||||||
|
go cleanupIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupIdleConnections periodically checks for and removes idle connections
|
||||||
|
func cleanupIdleConnections() {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
activeConnMu.Lock()
|
||||||
|
for token, conn := range activeConns {
|
||||||
|
if conn.LastUsed.Add(connTimeout).Before(now) {
|
||||||
|
logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName)
|
||||||
|
conn.Pool.Put(conn.Conn)
|
||||||
|
delete(activeConns, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
activeConnMu.Unlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupSQLite closes all database connections
|
// CleanupSQLite closes all database connections
|
||||||
func CleanupSQLite() {
|
func CleanupSQLite() {
|
||||||
|
activeConnMu.Lock()
|
||||||
|
for token, conn := range activeConns {
|
||||||
|
conn.Pool.Put(conn.Conn)
|
||||||
|
delete(activeConns, token)
|
||||||
|
}
|
||||||
|
activeConnMu.Unlock()
|
||||||
|
|
||||||
poolsMu.Lock()
|
poolsMu.Lock()
|
||||||
defer poolsMu.Unlock()
|
defer poolsMu.Unlock()
|
||||||
|
|
||||||
@ -80,6 +133,63 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
|||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getConnection retrieves or creates a tracked connection
|
||||||
|
func getConnection(token, dbName string) (*TrackedConn, string, error) {
|
||||||
|
// If token is provided, try to get existing connection
|
||||||
|
if token != "" {
|
||||||
|
activeConnMu.RLock()
|
||||||
|
conn, exists := activeConns[token]
|
||||||
|
activeConnMu.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
conn.LastUsed = time.Now()
|
||||||
|
return conn, token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token not provided or connection not found, create new
|
||||||
|
pool, err := getPool(dbName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := pool.Take(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new token
|
||||||
|
newToken := generateConnToken()
|
||||||
|
|
||||||
|
trackedConn := &TrackedConn{
|
||||||
|
Conn: conn,
|
||||||
|
Pool: pool,
|
||||||
|
DBName: dbName,
|
||||||
|
LastUsed: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
activeConnMu.Lock()
|
||||||
|
activeConns[newToken] = trackedConn
|
||||||
|
activeConnMu.Unlock()
|
||||||
|
|
||||||
|
return trackedConn, newToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// releaseConnection releases a connection back to the pool
|
||||||
|
func releaseConnection(token string) bool {
|
||||||
|
activeConnMu.Lock()
|
||||||
|
defer activeConnMu.Unlock()
|
||||||
|
|
||||||
|
conn, exists := activeConns[token]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Pool.Put(conn.Conn)
|
||||||
|
delete(activeConns, token)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// sqlQuery executes a SQL query and returns results
|
// sqlQuery executes a SQL query and returns results
|
||||||
func sqlQuery(state *luajit.State) int {
|
func sqlQuery(state *luajit.State) int {
|
||||||
// Get required parameters
|
// Get required parameters
|
||||||
@ -91,20 +201,20 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
dbName := state.ToString(1)
|
dbName := state.ToString(1)
|
||||||
query := state.ToString(2)
|
query := state.ToString(2)
|
||||||
|
|
||||||
// Get connection pool
|
// Get connection token (optional)
|
||||||
pool, err := getPool(dbName)
|
var connToken string
|
||||||
|
if state.GetTop() >= 4 && state.IsString(4) {
|
||||||
|
connToken = state.ToString(4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get connection
|
||||||
|
trackedConn, newToken, err := getConnection(connToken, dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a connection from the pool
|
conn := trackedConn.Conn
|
||||||
conn, err := pool.Take(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: failed to get connection: %s", err.Error()))
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
defer pool.Put(conn)
|
|
||||||
|
|
||||||
// Create execution options
|
// Create execution options
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
@ -145,6 +255,9 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
} else {
|
} else {
|
||||||
// Positional parameters
|
// Positional parameters
|
||||||
count := state.GetTop() - 2
|
count := state.GetTop() - 2
|
||||||
|
if state.IsString(4) {
|
||||||
|
count-- // Don't include connection token
|
||||||
|
}
|
||||||
args := make([]any, count)
|
args := make([]any, count)
|
||||||
for i := range count {
|
for i := range count {
|
||||||
idx := i + 3
|
idx := i + 3
|
||||||
@ -213,7 +326,10 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
state.SetTable(-3)
|
state.SetTable(-3)
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1
|
// Return connection token
|
||||||
|
state.PushString(newToken)
|
||||||
|
|
||||||
|
return 2
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlExec executes a SQL statement without returning results
|
// sqlExec executes a SQL statement without returning results
|
||||||
@ -227,20 +343,20 @@ func sqlExec(state *luajit.State) int {
|
|||||||
dbName := state.ToString(1)
|
dbName := state.ToString(1)
|
||||||
query := state.ToString(2)
|
query := state.ToString(2)
|
||||||
|
|
||||||
// Get connection pool
|
// Get connection token (optional)
|
||||||
pool, err := getPool(dbName)
|
var connToken string
|
||||||
|
if state.GetTop() >= 4 && state.IsString(4) {
|
||||||
|
connToken = state.ToString(4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get connection
|
||||||
|
trackedConn, newToken, err := getConnection(connToken, dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a connection from the pool
|
conn := trackedConn.Conn
|
||||||
conn, err := pool.Take(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: failed to get connection: %s", err.Error()))
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
defer pool.Put(conn)
|
|
||||||
|
|
||||||
// Check if parameters are provided
|
// Check if parameters are provided
|
||||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||||
@ -252,7 +368,8 @@ func sqlExec(state *luajit.State) int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
state.PushNumber(float64(conn.Changes()))
|
state.PushNumber(float64(conn.Changes()))
|
||||||
return 1
|
state.PushString(newToken)
|
||||||
|
return 2
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fast path for simple queries with no parameters
|
// Fast path for simple queries with no parameters
|
||||||
@ -263,7 +380,8 @@ func sqlExec(state *luajit.State) int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
state.PushNumber(float64(conn.Changes()))
|
state.PushNumber(float64(conn.Changes()))
|
||||||
return 1
|
state.PushString(newToken)
|
||||||
|
return 2
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create execution options for parameterized query
|
// Create execution options for parameterized query
|
||||||
@ -303,6 +421,9 @@ func sqlExec(state *luajit.State) int {
|
|||||||
} else {
|
} else {
|
||||||
// Positional parameters
|
// Positional parameters
|
||||||
count := state.GetTop() - 2
|
count := state.GetTop() - 2
|
||||||
|
if state.IsString(4) {
|
||||||
|
count-- // Don't include connection token
|
||||||
|
}
|
||||||
args := make([]any, count)
|
args := make([]any, count)
|
||||||
for i := range count {
|
for i := range count {
|
||||||
idx := i + 3
|
idx := i + 3
|
||||||
@ -333,8 +454,26 @@ func sqlExec(state *luajit.State) int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return affected rows
|
// Return affected rows and connection token
|
||||||
state.PushNumber(float64(conn.Changes()))
|
state.PushNumber(float64(conn.Changes()))
|
||||||
|
state.PushString(newToken)
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlClose releases a connection back to the pool
|
||||||
|
func sqlClose(state *luajit.State) int {
|
||||||
|
if state.GetTop() < 1 || !state.IsString(1) {
|
||||||
|
state.PushString("sqlite.close: requires connection token")
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
token := state.ToString(1)
|
||||||
|
if releaseConnection(token) {
|
||||||
|
state.PushBoolean(true)
|
||||||
|
} else {
|
||||||
|
state.PushBoolean(false)
|
||||||
|
}
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -343,5 +482,11 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
|
|||||||
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return state.RegisterGoFunction("__sqlite_exec", sqlExec)
|
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user