work towards reusable sqlite connections

This commit is contained in:
Sky Johnson 2025-06-05 19:04:05 -05:00
parent 0c4ddd7e3d
commit 4077ac03f1
4 changed files with 98 additions and 42 deletions

View File

@ -18,7 +18,7 @@ local connection_mt = {
end end
local normalized_params = normalize_params(params, ...) local normalized_params = normalize_params(params, ...)
return __sqlite_query(self.db_name, query, normalized_params) return __sqlite_query(self.db_name, query, normalized_params, __STATE_INDEX)
end, end,
exec = function(self, query, params, ...) exec = function(self, query, params, ...)
@ -27,7 +27,7 @@ local connection_mt = {
end end
local normalized_params = normalize_params(params, ...) local normalized_params = normalize_params(params, ...)
return __sqlite_exec(self.db_name, query, normalized_params) return __sqlite_exec(self.db_name, query, normalized_params, __STATE_INDEX)
end, end,
get_one = function(self, query, params, ...) get_one = function(self, query, params, ...)
@ -36,7 +36,7 @@ local connection_mt = {
end end
local normalized_params = normalize_params(params, ...) local normalized_params = normalize_params(params, ...)
return __sqlite_get_one(self.db_name, query, normalized_params) return __sqlite_get_one(self.db_name, query, normalized_params, __STATE_INDEX)
end, end,
insert = function(self, table_name, data, columns) insert = function(self, table_name, data, columns)

View File

@ -122,7 +122,7 @@ func (r *Runner) ExecuteHTTP(bytecode []byte, httpCtx *fasthttp.RequestCtx,
luaCtx := r.buildHTTPContext(httpCtx, params, session) luaCtx := r.buildHTTPContext(httpCtx, params, session)
defer r.releaseHTTPContext(luaCtx) defer r.releaseHTTPContext(luaCtx)
return state.sandbox.Execute(state.L, bytecode, luaCtx) return state.sandbox.Execute(state.L, bytecode, luaCtx, state.index)
} }
// Build Lua context from HTTP request // Build Lua context from HTTP request
@ -463,7 +463,7 @@ func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
ctx.Set("_script_dir", scriptDir) ctx.Set("_script_dir", scriptDir)
// Execute script // Execute script
response, err := state.sandbox.Execute(state.L, bytecode, ctx) response, err := state.sandbox.Execute(state.L, bytecode, ctx, state.index)
if err != nil { if err != nil {
return nil, fmt.Errorf("execution error: %w", err) return nil, fmt.Errorf("execution error: %w", err)
} }

View File

@ -49,7 +49,7 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int, verbose bool) error
} }
// Execute runs a Lua script in the sandbox with the given context // Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context, stateIndex int) (*Response, error) {
// Load script and executor // Load script and executor
if err := state.LoadBytecode(bytecode, "script"); err != nil { if err := state.LoadBytecode(bytecode, "script"); err != nil {
return nil, fmt.Errorf("failed to load bytecode: %w", err) return nil, fmt.Errorf("failed to load bytecode: %w", err)
@ -90,6 +90,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
result, _ := state.ToValue(-1) result, _ := state.ToValue(-1)
state.Pop(2) // Clean up state.Pop(2) // Clean up
sqlite.CleanupStateConnection(stateIndex)
var modifiedResponse map[string]any var modifiedResponse map[string]any
var scriptResult any var scriptResult any

View File

@ -23,8 +23,18 @@ var (
dataDir string dataDir string
poolSize = 8 poolSize = 8
connTimeout = 5 * time.Second connTimeout = 5 * time.Second
// Per-state connection cache
stateConns = make(map[string]*stateConn)
stateConnsMu sync.RWMutex
) )
// stateConn tracks a connection and its origin pool
type stateConn struct {
conn *sqlite.Conn
pool *sqlitex.Pool
}
func InitSQLite(dir string) { func InitSQLite(dir string) {
dataDir = dir dataDir = dir
logger.Infof("SQLite is g2g! %s", color.Yellow(dir)) logger.Infof("SQLite is g2g! %s", color.Yellow(dir))
@ -40,6 +50,16 @@ func CleanupSQLite() {
poolsMu.Lock() poolsMu.Lock()
defer poolsMu.Unlock() defer poolsMu.Unlock()
// Return all cached connections to their pools
stateConnsMu.Lock()
for _, sc := range stateConns {
if sc.pool != nil && sc.conn != nil {
sc.pool.Put(sc.conn)
}
}
stateConns = make(map[string]*stateConn)
stateConnsMu.Unlock()
for name, pool := range dbPools { for name, pool := range dbPools {
if err := pool.Close(); err != nil { if err := pool.Close(); err != nil {
logger.Errorf("Failed to close database %s: %v", name, err) logger.Errorf("Failed to close database %s: %v", name, err)
@ -99,8 +119,45 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return pool, nil return pool, nil
} }
// getStateConnection gets or creates a reusable connection for the state+db
func getStateConnection(stateIndex int, dbName string) (*sqlite.Conn, error) {
connKey := fmt.Sprintf("%d-%s", stateIndex, dbName)
stateConnsMu.RLock()
sc, exists := stateConns[connKey]
stateConnsMu.RUnlock()
if exists && sc.conn != nil {
return sc.conn, nil
}
// Get new connection from pool
pool, err := getPool(dbName)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return nil, fmt.Errorf("connection timeout: %w", err)
}
// Cache it with pool reference
stateConnsMu.Lock()
stateConns[connKey] = &stateConn{
conn: conn,
pool: pool,
}
stateConnsMu.Unlock()
return conn, nil
}
func sqlQuery(state *luajit.State) int { func sqlQuery(state *luajit.State) int {
if err := state.CheckMinArgs(2); err != nil { if err := state.CheckMinArgs(3); err != nil {
return state.PushError("sqlite.query: %v", err) return state.PushError("sqlite.query: %v", err)
} }
@ -114,24 +171,17 @@ func sqlQuery(state *luajit.State) int {
return state.PushError("sqlite.query: query must be string") return state.PushError("sqlite.query: query must be string")
} }
pool, err := getPool(dbName) stateIndex := int(state.ToNumber(-1))
conn, err := getStateConnection(stateIndex, dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.query: %v", err) return state.PushError("sqlite.query: %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return state.PushError("sqlite.query: connection timeout: %v", err)
}
defer pool.Put(conn)
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
rows := make([]any, 0, 16) rows := make([]any, 0, 16)
if state.GetTop() >= 3 && !state.IsNil(3) { if state.GetTop() >= 4 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil { if err := setupParams(state, 3, &execOpts); err != nil {
return state.PushError("sqlite.query: %v", err) return state.PushError("sqlite.query: %v", err)
} }
@ -178,7 +228,7 @@ func sqlQuery(state *luajit.State) int {
} }
func sqlExec(state *luajit.State) int { func sqlExec(state *luajit.State) int {
if err := state.CheckMinArgs(2); err != nil { if err := state.CheckMinArgs(3); err != nil {
return state.PushError("sqlite.exec: %v", err) return state.PushError("sqlite.exec: %v", err)
} }
@ -192,21 +242,14 @@ func sqlExec(state *luajit.State) int {
return state.PushError("sqlite.exec: query must be string") return state.PushError("sqlite.exec: query must be string")
} }
pool, err := getPool(dbName) stateIndex := int(state.ToNumber(-1))
conn, err := getStateConnection(stateIndex, dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.exec: %v", err) return state.PushError("sqlite.exec: %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), connTimeout) hasParams := state.GetTop() >= 4 && !state.IsNil(3)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return state.PushError("sqlite.exec: connection timeout: %v", err)
}
defer pool.Put(conn)
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
if strings.Contains(query, ";") && !hasParams { if strings.Contains(query, ";") && !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil { if err := sqlitex.ExecScript(conn, query); err != nil {
@ -313,7 +356,7 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
} }
func sqlGetOne(state *luajit.State) int { func sqlGetOne(state *luajit.State) int {
if err := state.CheckMinArgs(2); err != nil { if err := state.CheckMinArgs(3); err != nil {
return state.PushError("sqlite.get_one: %v", err) return state.PushError("sqlite.get_one: %v", err)
} }
@ -327,24 +370,18 @@ func sqlGetOne(state *luajit.State) int {
return state.PushError("sqlite.get_one: query must be string") return state.PushError("sqlite.get_one: query must be string")
} }
pool, err := getPool(dbName) stateIndex := int(state.ToNumber(-1))
conn, err := getStateConnection(stateIndex, dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.get_one: %v", err) return state.PushError("sqlite.get_one: %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return state.PushError("sqlite.get_one: connection timeout: %v", err)
}
defer pool.Put(conn)
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
var result map[string]any var result map[string]any
if state.GetTop() >= 3 && !state.IsNil(3) { // Check if params provided (before state index)
if state.GetTop() >= 4 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil { if err := setupParams(state, 3, &execOpts); err != nil {
return state.PushError("sqlite.get_one: %v", err) return state.PushError("sqlite.get_one: %v", err)
} }
@ -397,6 +434,23 @@ func sqlGetOne(state *luajit.State) int {
return 1 return 1
} }
// CleanupStateConnection releases all connections for a specific state
func CleanupStateConnection(stateIndex int) {
stateConnsMu.Lock()
defer stateConnsMu.Unlock()
statePrefix := fmt.Sprintf("%d-", stateIndex)
for key, sc := range stateConns {
if strings.HasPrefix(key, statePrefix) {
if sc.pool != nil && sc.conn != nil {
sc.pool.Put(sc.conn)
}
delete(stateConns, key)
}
}
}
func RegisterSQLiteFunctions(state *luajit.State) error { func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
return err return err