work towards reusable sqlite connections
This commit is contained in:
parent
0c4ddd7e3d
commit
4077ac03f1
@ -18,7 +18,7 @@ local connection_mt = {
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_query(self.db_name, query, normalized_params)
|
||||
return __sqlite_query(self.db_name, query, normalized_params, __STATE_INDEX)
|
||||
end,
|
||||
|
||||
exec = function(self, query, params, ...)
|
||||
@ -27,7 +27,7 @@ local connection_mt = {
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_exec(self.db_name, query, normalized_params)
|
||||
return __sqlite_exec(self.db_name, query, normalized_params, __STATE_INDEX)
|
||||
end,
|
||||
|
||||
get_one = function(self, query, params, ...)
|
||||
@ -36,7 +36,7 @@ local connection_mt = {
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_get_one(self.db_name, query, normalized_params)
|
||||
return __sqlite_get_one(self.db_name, query, normalized_params, __STATE_INDEX)
|
||||
end,
|
||||
|
||||
insert = function(self, table_name, data, columns)
|
||||
|
@ -122,7 +122,7 @@ func (r *Runner) ExecuteHTTP(bytecode []byte, httpCtx *fasthttp.RequestCtx,
|
||||
luaCtx := r.buildHTTPContext(httpCtx, params, session)
|
||||
defer r.releaseHTTPContext(luaCtx)
|
||||
|
||||
return state.sandbox.Execute(state.L, bytecode, luaCtx)
|
||||
return state.sandbox.Execute(state.L, bytecode, luaCtx, state.index)
|
||||
}
|
||||
|
||||
// Build Lua context from HTTP request
|
||||
@ -463,7 +463,7 @@ func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
|
||||
ctx.Set("_script_dir", scriptDir)
|
||||
|
||||
// Execute script
|
||||
response, err := state.sandbox.Execute(state.L, bytecode, ctx)
|
||||
response, err := state.sandbox.Execute(state.L, bytecode, ctx, state.index)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execution error: %w", err)
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int, verbose bool) error
|
||||
}
|
||||
|
||||
// Execute runs a Lua script in the sandbox with the given context
|
||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
|
||||
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context, stateIndex int) (*Response, error) {
|
||||
// Load script and executor
|
||||
if err := state.LoadBytecode(bytecode, "script"); err != nil {
|
||||
return nil, fmt.Errorf("failed to load bytecode: %w", err)
|
||||
@ -90,6 +90,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*
|
||||
result, _ := state.ToValue(-1)
|
||||
state.Pop(2) // Clean up
|
||||
|
||||
sqlite.CleanupStateConnection(stateIndex)
|
||||
|
||||
var modifiedResponse map[string]any
|
||||
var scriptResult any
|
||||
|
||||
|
@ -23,8 +23,18 @@ var (
|
||||
dataDir string
|
||||
poolSize = 8
|
||||
connTimeout = 5 * time.Second
|
||||
|
||||
// Per-state connection cache
|
||||
stateConns = make(map[string]*stateConn)
|
||||
stateConnsMu sync.RWMutex
|
||||
)
|
||||
|
||||
// stateConn tracks a connection and its origin pool
|
||||
type stateConn struct {
|
||||
conn *sqlite.Conn
|
||||
pool *sqlitex.Pool
|
||||
}
|
||||
|
||||
func InitSQLite(dir string) {
|
||||
dataDir = dir
|
||||
logger.Infof("SQLite is g2g! %s", color.Yellow(dir))
|
||||
@ -40,6 +50,16 @@ func CleanupSQLite() {
|
||||
poolsMu.Lock()
|
||||
defer poolsMu.Unlock()
|
||||
|
||||
// Return all cached connections to their pools
|
||||
stateConnsMu.Lock()
|
||||
for _, sc := range stateConns {
|
||||
if sc.pool != nil && sc.conn != nil {
|
||||
sc.pool.Put(sc.conn)
|
||||
}
|
||||
}
|
||||
stateConns = make(map[string]*stateConn)
|
||||
stateConnsMu.Unlock()
|
||||
|
||||
for name, pool := range dbPools {
|
||||
if err := pool.Close(); err != nil {
|
||||
logger.Errorf("Failed to close database %s: %v", name, err)
|
||||
@ -99,8 +119,45 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// getStateConnection gets or creates a reusable connection for the state+db
|
||||
func getStateConnection(stateIndex int, dbName string) (*sqlite.Conn, error) {
|
||||
connKey := fmt.Sprintf("%d-%s", stateIndex, dbName)
|
||||
|
||||
stateConnsMu.RLock()
|
||||
sc, exists := stateConns[connKey]
|
||||
stateConnsMu.RUnlock()
|
||||
|
||||
if exists && sc.conn != nil {
|
||||
return sc.conn, nil
|
||||
}
|
||||
|
||||
// Get new connection from pool
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connection timeout: %w", err)
|
||||
}
|
||||
|
||||
// Cache it with pool reference
|
||||
stateConnsMu.Lock()
|
||||
stateConns[connKey] = &stateConn{
|
||||
conn: conn,
|
||||
pool: pool,
|
||||
}
|
||||
stateConnsMu.Unlock()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func sqlQuery(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(2); err != nil {
|
||||
if err := state.CheckMinArgs(3); err != nil {
|
||||
return state.PushError("sqlite.query: %v", err)
|
||||
}
|
||||
|
||||
@ -114,24 +171,17 @@ func sqlQuery(state *luajit.State) int {
|
||||
return state.PushError("sqlite.query: query must be string")
|
||||
}
|
||||
|
||||
pool, err := getPool(dbName)
|
||||
stateIndex := int(state.ToNumber(-1))
|
||||
|
||||
conn, err := getStateConnection(stateIndex, dbName)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.query: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.query: connection timeout: %v", err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
var execOpts sqlitex.ExecOptions
|
||||
rows := make([]any, 0, 16)
|
||||
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
if state.GetTop() >= 4 && !state.IsNil(3) {
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
return state.PushError("sqlite.query: %v", err)
|
||||
}
|
||||
@ -178,7 +228,7 @@ func sqlQuery(state *luajit.State) int {
|
||||
}
|
||||
|
||||
func sqlExec(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(2); err != nil {
|
||||
if err := state.CheckMinArgs(3); err != nil {
|
||||
return state.PushError("sqlite.exec: %v", err)
|
||||
}
|
||||
|
||||
@ -192,21 +242,14 @@ func sqlExec(state *luajit.State) int {
|
||||
return state.PushError("sqlite.exec: query must be string")
|
||||
}
|
||||
|
||||
pool, err := getPool(dbName)
|
||||
stateIndex := int(state.ToNumber(-1))
|
||||
|
||||
conn, err := getStateConnection(stateIndex, dbName)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.exec: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.exec: connection timeout: %v", err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||
hasParams := state.GetTop() >= 4 && !state.IsNil(3)
|
||||
|
||||
if strings.Contains(query, ";") && !hasParams {
|
||||
if err := sqlitex.ExecScript(conn, query); err != nil {
|
||||
@ -313,7 +356,7 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
|
||||
}
|
||||
|
||||
func sqlGetOne(state *luajit.State) int {
|
||||
if err := state.CheckMinArgs(2); err != nil {
|
||||
if err := state.CheckMinArgs(3); err != nil {
|
||||
return state.PushError("sqlite.get_one: %v", err)
|
||||
}
|
||||
|
||||
@ -327,24 +370,18 @@ func sqlGetOne(state *luajit.State) int {
|
||||
return state.PushError("sqlite.get_one: query must be string")
|
||||
}
|
||||
|
||||
pool, err := getPool(dbName)
|
||||
stateIndex := int(state.ToNumber(-1))
|
||||
|
||||
conn, err := getStateConnection(stateIndex, dbName)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.get_one: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
return state.PushError("sqlite.get_one: connection timeout: %v", err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
var execOpts sqlitex.ExecOptions
|
||||
var result map[string]any
|
||||
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
// Check if params provided (before state index)
|
||||
if state.GetTop() >= 4 && !state.IsNil(3) {
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
return state.PushError("sqlite.get_one: %v", err)
|
||||
}
|
||||
@ -397,6 +434,23 @@ func sqlGetOne(state *luajit.State) int {
|
||||
return 1
|
||||
}
|
||||
|
||||
// CleanupStateConnection releases all connections for a specific state
|
||||
func CleanupStateConnection(stateIndex int) {
|
||||
stateConnsMu.Lock()
|
||||
defer stateConnsMu.Unlock()
|
||||
|
||||
statePrefix := fmt.Sprintf("%d-", stateIndex)
|
||||
|
||||
for key, sc := range stateConns {
|
||||
if strings.HasPrefix(key, statePrefix) {
|
||||
if sc.pool != nil && sc.conn != nil {
|
||||
sc.pool.Put(sc.conn)
|
||||
}
|
||||
delete(stateConns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterSQLiteFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
||||
return err
|
||||
|
Loading…
x
Reference in New Issue
Block a user