remove all connection tracking

This commit is contained in:
Sky Johnson 2025-05-10 15:23:57 -05:00
parent 8f9a9da5a1
commit 0581d72065
2 changed files with 55 additions and 152 deletions

View File

@ -145,7 +145,6 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
// Execute with 2 args, 1 result // Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil { if err := state.Call(2, 1); err != nil {
ReleaseActiveConnections(state)
return nil, fmt.Errorf("script execution failed: %w", err) return nil, fmt.Errorf("script execution failed: %w", err)
} }
@ -159,8 +158,6 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
extractHTTPResponseData(state, response) extractHTTPResponseData(state, response)
ReleaseActiveConnections(state)
return response, nil return response, nil
} }

View File

@ -2,7 +2,6 @@ package runner
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"path/filepath" "path/filepath"
"strings" "strings"
@ -16,174 +15,69 @@ import (
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
) )
// SQLiteConnection tracks an active connection // DbPools maintains database connection pools
type SQLiteConnection struct { var (
Conn *sqlite.Conn dbPools = make(map[string]*sqlitex.Pool)
Pool *sqlitex.Pool poolsMu sync.RWMutex
} dataDir string
)
// SQLiteManager handles database connections // InitSQLite initializes the SQLite subsystem
type SQLiteManager struct { func InitSQLite(dir string) {
mu sync.RWMutex dataDir = dir
pools map[string]*sqlitex.Pool logger.Server("SQLite initialized with data directory: %s", dir)
activeConns map[string]*SQLiteConnection
dataDir string
}
var sqliteManager *SQLiteManager
// InitSQLite initializes the SQLite manager
func InitSQLite(dataDir string) {
sqliteManager = &SQLiteManager{
pools: make(map[string]*sqlitex.Pool),
activeConns: make(map[string]*SQLiteConnection),
dataDir: dataDir,
}
logger.Server("SQLite initialized with data directory: %s", dataDir)
} }
// CleanupSQLite closes all database connections // CleanupSQLite closes all database connections
func CleanupSQLite() { func CleanupSQLite() {
if sqliteManager == nil { poolsMu.Lock()
return defer poolsMu.Unlock()
}
sqliteManager.mu.Lock() for name, pool := range dbPools {
defer sqliteManager.mu.Unlock()
for id, conn := range sqliteManager.activeConns {
if conn.Pool != nil {
conn.Pool.Put(conn.Conn)
}
delete(sqliteManager.activeConns, id)
}
for name, pool := range sqliteManager.pools {
if err := pool.Close(); err != nil { if err := pool.Close(); err != nil {
logger.Error("Failed to close database %s: %v", name, err) logger.Error("Failed to close database %s: %v", name, err)
} }
} }
sqliteManager.pools = nil dbPools = make(map[string]*sqlitex.Pool)
sqliteManager.activeConns = nil
logger.Debug("SQLite connections closed") logger.Debug("SQLite connections closed")
} }
// ReleaseActiveConnections returns all active connections to their pools // getPool returns a connection pool for the database
func ReleaseActiveConnections(state *luajit.State) { func getPool(dbName string) (*sqlitex.Pool, error) {
if sqliteManager == nil {
return
}
// Get active connections table from Lua
state.GetGlobal("__active_sqlite_connections")
if !state.IsTable(-1) {
state.Pop(1)
return
}
sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock()
// Iterate through active connections
state.PushNil() // Start iteration
for state.Next(-2) {
if state.IsTable(-1) {
state.GetField(-1, "id")
if state.IsString(-1) {
connID := state.ToString(-1)
if conn, exists := sqliteManager.activeConns[connID]; exists {
if conn.Pool != nil {
conn.Pool.Put(conn.Conn)
}
delete(sqliteManager.activeConns, connID)
}
}
state.Pop(1) // Pop connection id
}
state.Pop(1) // Pop value, leave key for next iteration
}
// Clear the active connections table
state.PushNil()
state.SetGlobal("__active_sqlite_connections")
}
// getConnection returns a connection for the database
func getConnection(dbName, connID string) (*sqlite.Conn, error) {
if sqliteManager == nil {
return nil, errors.New("SQLite not initialized")
}
// Validate database name // Validate database name
dbName = filepath.Base(dbName) dbName = filepath.Base(dbName)
if dbName == "" || dbName[0] == '.' { if dbName == "" || dbName[0] == '.' {
return nil, errors.New("invalid database name") return nil, fmt.Errorf("invalid database name")
} }
// Check for existing connection // Check for existing pool
sqliteManager.mu.RLock() poolsMu.RLock()
conn, exists := sqliteManager.activeConns[connID] pool, exists := dbPools[dbName]
if exists { if exists {
sqliteManager.mu.RUnlock() poolsMu.RUnlock()
return conn.Conn, nil return pool, nil
} }
sqliteManager.mu.RUnlock() poolsMu.RUnlock()
// Get or create pool under write lock // Create new pool under write lock
sqliteManager.mu.Lock() poolsMu.Lock()
defer sqliteManager.mu.Unlock() defer poolsMu.Unlock()
// Double-check if a connection was created while waiting for lock // Double-check if a pool was created while waiting for lock
if conn, exists = sqliteManager.activeConns[connID]; exists { if pool, exists = dbPools[dbName]; exists {
return conn.Conn, nil return pool, nil
} }
// Get or create pool // Create new pool
pool, exists := sqliteManager.pools[dbName] dbPath := filepath.Join(dataDir, dbName+".db")
if !exists { pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
var err error
pool, err = sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
sqliteManager.pools[dbName] = pool
}
// Get a connection
dbConn, err := pool.Take(context.Background())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get connection from pool: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
// Store connection dbPools[dbName] = pool
sqliteManager.activeConns[connID] = &SQLiteConnection{ return pool, nil
Conn: dbConn,
Pool: pool,
}
return dbConn, nil
}
// releaseConnection returns a connection to its pool
func releaseConnection(connID string) {
if sqliteManager == nil {
return
}
sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock()
conn, exists := sqliteManager.activeConns[connID]
if !exists {
return
}
if conn.Pool != nil {
conn.Pool.Put(conn.Conn)
}
delete(sqliteManager.activeConns, connID)
} }
// sqlQuery executes a SQL query and returns results // sqlQuery executes a SQL query and returns results
@ -196,15 +90,21 @@ func sqlQuery(state *luajit.State) int {
dbName := state.ToString(1) dbName := state.ToString(1)
query := state.ToString(2) query := state.ToString(2)
connID := fmt.Sprintf("temp_%p", &query)
// Get connection // Get connection pool
conn, err := getConnection(dbName, connID) pool, err := getPool(dbName)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1 return -1
} }
defer releaseConnection(connID)
// Get a connection from the pool
conn, err := pool.Take(context.Background())
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: failed to get connection: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Create execution options // Create execution options
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
@ -326,15 +226,21 @@ func sqlExec(state *luajit.State) int {
dbName := state.ToString(1) dbName := state.ToString(1)
query := state.ToString(2) query := state.ToString(2)
connID := fmt.Sprintf("temp_%p", &query)
// Get connection // Get connection pool
conn, err := getConnection(dbName, connID) pool, err := getPool(dbName)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1 return -1
} }
defer releaseConnection(connID)
// Get a connection from the pool
conn, err := pool.Take(context.Background())
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: failed to get connection: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Check if parameters are provided // Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3) hasParams := state.GetTop() >= 3 && !state.IsNil(3)