diff --git a/runner/lua/sqlite.lua b/runner/lua/sqlite.lua index 5d4096e..7b979b7 100644 --- a/runner/lua/sqlite.lua +++ b/runner/lua/sqlite.lua @@ -18,7 +18,7 @@ local connection_mt = { end local normalized_params = normalize_params(params, ...) - return __sqlite_query(self.db_name, query, normalized_params) + return __sqlite_query(self.db_name, query, normalized_params, __STATE_INDEX) end, exec = function(self, query, params, ...) @@ -27,7 +27,7 @@ local connection_mt = { end local normalized_params = normalize_params(params, ...) - return __sqlite_exec(self.db_name, query, normalized_params) + return __sqlite_exec(self.db_name, query, normalized_params, __STATE_INDEX) end, get_one = function(self, query, params, ...) @@ -36,7 +36,7 @@ local connection_mt = { end local normalized_params = normalize_params(params, ...) - return __sqlite_get_one(self.db_name, query, normalized_params) + return __sqlite_get_one(self.db_name, query, normalized_params, __STATE_INDEX) end, insert = function(self, table_name, data, columns) diff --git a/runner/runner.go b/runner/runner.go index a7de247..74d10b0 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -122,7 +122,7 @@ func (r *Runner) ExecuteHTTP(bytecode []byte, httpCtx *fasthttp.RequestCtx, luaCtx := r.buildHTTPContext(httpCtx, params, session) defer r.releaseHTTPContext(luaCtx) - return state.sandbox.Execute(state.L, bytecode, luaCtx) + return state.sandbox.Execute(state.L, bytecode, luaCtx, state.index) } // Build Lua context from HTTP request @@ -463,7 +463,7 @@ func (r *Runner) RunScriptFile(filePath string) (*Response, error) { ctx.Set("_script_dir", scriptDir) // Execute script - response, err := state.sandbox.Execute(state.L, bytecode, ctx) + response, err := state.sandbox.Execute(state.L, bytecode, ctx, state.index) if err != nil { return nil, fmt.Errorf("execution error: %w", err) } diff --git a/runner/sandbox.go b/runner/sandbox.go index 8c64663..61f91db 100644 --- a/runner/sandbox.go +++ b/runner/sandbox.go @@ -49,7 +49,7 @@ func (s *Sandbox) Setup(state *luajit.State, stateIndex int, verbose bool) error } // Execute runs a Lua script in the sandbox with the given context -func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { +func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context, stateIndex int) (*Response, error) { // Load script and executor if err := state.LoadBytecode(bytecode, "script"); err != nil { return nil, fmt.Errorf("failed to load bytecode: %w", err) @@ -90,6 +90,8 @@ func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (* result, _ := state.ToValue(-1) state.Pop(2) // Clean up + sqlite.CleanupStateConnection(stateIndex) + var modifiedResponse map[string]any var scriptResult any diff --git a/runner/sqlite/sqlite.go b/runner/sqlite/sqlite.go index d430b8a..e35503a 100644 --- a/runner/sqlite/sqlite.go +++ b/runner/sqlite/sqlite.go @@ -23,8 +23,18 @@ var ( dataDir string poolSize = 8 connTimeout = 5 * time.Second + + // Per-state connection cache + stateConns = make(map[string]*stateConn) + stateConnsMu sync.RWMutex ) +// stateConn tracks a connection and its origin pool +type stateConn struct { + conn *sqlite.Conn + pool *sqlitex.Pool +} + func InitSQLite(dir string) { dataDir = dir logger.Infof("SQLite is g2g! %s", color.Yellow(dir)) @@ -40,6 +50,16 @@ func CleanupSQLite() { poolsMu.Lock() defer poolsMu.Unlock() + // Return all cached connections to their pools + stateConnsMu.Lock() + for _, sc := range stateConns { + if sc.pool != nil && sc.conn != nil { + sc.pool.Put(sc.conn) + } + } + stateConns = make(map[string]*stateConn) + stateConnsMu.Unlock() + for name, pool := range dbPools { if err := pool.Close(); err != nil { logger.Errorf("Failed to close database %s: %v", name, err) @@ -99,8 +119,45 @@ func getPool(dbName string) (*sqlitex.Pool, error) { return pool, nil } +// getStateConnection gets or creates a reusable connection for the state+db +func getStateConnection(stateIndex int, dbName string) (*sqlite.Conn, error) { + connKey := fmt.Sprintf("%d-%s", stateIndex, dbName) + + stateConnsMu.RLock() + sc, exists := stateConns[connKey] + stateConnsMu.RUnlock() + + if exists && sc.conn != nil { + return sc.conn, nil + } + + // Get new connection from pool + pool, err := getPool(dbName) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), connTimeout) + defer cancel() + + conn, err := pool.Take(ctx) + if err != nil { + return nil, fmt.Errorf("connection timeout: %w", err) + } + + // Cache it with pool reference + stateConnsMu.Lock() + stateConns[connKey] = &stateConn{ + conn: conn, + pool: pool, + } + stateConnsMu.Unlock() + + return conn, nil +} + func sqlQuery(state *luajit.State) int { - if err := state.CheckMinArgs(2); err != nil { + if err := state.CheckMinArgs(3); err != nil { return state.PushError("sqlite.query: %v", err) } @@ -114,24 +171,17 @@ func sqlQuery(state *luajit.State) int { return state.PushError("sqlite.query: query must be string") } - pool, err := getPool(dbName) + stateIndex := int(state.ToNumber(-1)) + + conn, err := getStateConnection(stateIndex, dbName) if err != nil { return state.PushError("sqlite.query: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), connTimeout) - defer cancel() - - conn, err := pool.Take(ctx) - if err != nil { - return state.PushError("sqlite.query: connection timeout: %v", err) - } - defer pool.Put(conn) - var execOpts sqlitex.ExecOptions rows := make([]any, 0, 16) - if state.GetTop() >= 3 && !state.IsNil(3) { + if state.GetTop() >= 4 && !state.IsNil(3) { if err := setupParams(state, 3, &execOpts); err != nil { return state.PushError("sqlite.query: %v", err) } @@ -178,7 +228,7 @@ func sqlQuery(state *luajit.State) int { } func sqlExec(state *luajit.State) int { - if err := state.CheckMinArgs(2); err != nil { + if err := state.CheckMinArgs(3); err != nil { return state.PushError("sqlite.exec: %v", err) } @@ -192,21 +242,14 @@ func sqlExec(state *luajit.State) int { return state.PushError("sqlite.exec: query must be string") } - pool, err := getPool(dbName) + stateIndex := int(state.ToNumber(-1)) + + conn, err := getStateConnection(stateIndex, dbName) if err != nil { return state.PushError("sqlite.exec: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), connTimeout) - defer cancel() - - conn, err := pool.Take(ctx) - if err != nil { - return state.PushError("sqlite.exec: connection timeout: %v", err) - } - defer pool.Put(conn) - - hasParams := state.GetTop() >= 3 && !state.IsNil(3) + hasParams := state.GetTop() >= 4 && !state.IsNil(3) if strings.Contains(query, ";") && !hasParams { if err := sqlitex.ExecScript(conn, query); err != nil { @@ -313,7 +356,7 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti } func sqlGetOne(state *luajit.State) int { - if err := state.CheckMinArgs(2); err != nil { + if err := state.CheckMinArgs(3); err != nil { return state.PushError("sqlite.get_one: %v", err) } @@ -327,24 +370,18 @@ func sqlGetOne(state *luajit.State) int { return state.PushError("sqlite.get_one: query must be string") } - pool, err := getPool(dbName) + stateIndex := int(state.ToNumber(-1)) + + conn, err := getStateConnection(stateIndex, dbName) if err != nil { return state.PushError("sqlite.get_one: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), connTimeout) - defer cancel() - - conn, err := pool.Take(ctx) - if err != nil { - return state.PushError("sqlite.get_one: connection timeout: %v", err) - } - defer pool.Put(conn) - var execOpts sqlitex.ExecOptions var result map[string]any - if state.GetTop() >= 3 && !state.IsNil(3) { + // Check if params provided (before state index) + if state.GetTop() >= 4 && !state.IsNil(3) { if err := setupParams(state, 3, &execOpts); err != nil { return state.PushError("sqlite.get_one: %v", err) } @@ -397,6 +434,23 @@ func sqlGetOne(state *luajit.State) int { return 1 } +// CleanupStateConnection releases all connections for a specific state +func CleanupStateConnection(stateIndex int) { + stateConnsMu.Lock() + defer stateConnsMu.Unlock() + + statePrefix := fmt.Sprintf("%d-", stateIndex) + + for key, sc := range stateConns { + if strings.HasPrefix(key, statePrefix) { + if sc.pool != nil && sc.conn != nil { + sc.pool.Put(sc.conn) + } + delete(stateConns, key) + } + } +} + func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { return err