From 266da9fd239f57115d089cc98a667605dd1495a2 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Fri, 30 May 2025 13:24:58 -0500 Subject: [PATCH] major sqlite bug fix, minor lua state closing fix, add headers to lua ctx --- http/server.go | 9 +- runner/lua/sqlite.lua | 37 +--- runner/runner.go | 101 +++++++--- runner/sqlite.go | 424 ++++++++++++++++++------------------------ 4 files changed, 268 insertions(+), 303 deletions(-) diff --git a/http/server.go b/http/server.go index 424e1a3..cc9d84c 100644 --- a/http/server.go +++ b/http/server.go @@ -210,10 +210,17 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip luaCtx.Set("host", string(ctx.Host())) luaCtx.Set("session", sessionMap) + // Add headers to context + headers := make(map[string]any) + ctx.Request.Header.VisitAll(func(key, value []byte) { + headers[string(key)] = string(value) + }) + luaCtx.Set("headers", headers) + // Handle params if params != nil && params.Count > 0 { paramMap := s.paramsPool.Get().(map[string]any) - for i := 0; i < params.Count; i++ { + for i := range params.Count { paramMap[params.Keys[i]] = params.Values[i] } luaCtx.Set("params", paramMap) diff --git a/runner/lua/sqlite.lua b/runner/lua/sqlite.lua index a5a66fc..5b042dd 100644 --- a/runner/lua/sqlite.lua +++ b/runner/lua/sqlite.lua @@ -16,9 +16,7 @@ local connection_mt = { end local normalized_params = normalize_params(params, ...) - local results, token = __sqlite_query(self.db_name, query, normalized_params, self.conn_token) - self.conn_token = token - return results + return __sqlite_query(self.db_name, query, normalized_params) end, exec = function(self, query, params, ...) @@ -27,18 +25,16 @@ local connection_mt = { end local normalized_params = normalize_params(params, ...) - local affected, token = __sqlite_exec(self.db_name, query, normalized_params, self.conn_token) - self.conn_token = token - return affected + return __sqlite_exec(self.db_name, query, normalized_params) end, - close = function(self) - if self.conn_token then - local success = __sqlite_close(self.conn_token) - self.conn_token = nil - return success + get_one = function(self, query, params, ...) + if type(query) ~= "string" then + error("connection:get_one: query must be a string", 2) end - return false + + local normalized_params = normalize_params(params, ...) + return __sqlite_get_one(self.db_name, query, normalized_params) end, insert = function(self, table_name, data, columns) @@ -249,20 +245,6 @@ local connection_mt = { return self:exec(query, normalize_params(params, ...)) end, - get_one = function(self, query, params, ...) - if type(query) ~= "string" then - error("connection:get_one: query must be a string", 2) - end - - local limited_query = query - if not string.contains(query:lower(), "limit") then - limited_query = query .. " LIMIT 1" - end - - local results = self:query(limited_query, normalize_params(params, ...)) - return results[1] - end, - exists = function(self, table_name, where, params, ...) if type(table_name) ~= "string" then error("connection:exists: table_name must be a string", 2) @@ -310,7 +292,6 @@ return function(db_name) end return setmetatable({ - db_name = db_name, - conn_token = nil + db_name = db_name }, connection_mt) end diff --git a/runner/runner.go b/runner/runner.go index 3011338..8c9267a 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -33,7 +33,7 @@ type State struct { L *luajit.State // The Lua state sandbox *Sandbox // Associated sandbox index int // Index for debugging - inUse bool // Whether the state is currently in use + inUse atomic.Bool // Whether the state is currently in use } // Runner runs Lua scripts using a pool of Lua states @@ -115,14 +115,16 @@ func NewRunner(options ...RunnerOption) (*Runner, error) { InitSQLite(runner.dataDir) InitFS(runner.fsDir) + SetSQLitePoolSize(runner.poolSize) + // Initialize states and pool runner.states = make([]*State, runner.poolSize) runner.statePool = make(chan int, runner.poolSize) // Create and initialize all states if err := runner.initializeStates(); err != nil { - CleanupSQLite() // Clean up SQLite connections - runner.Close() // Clean up already created states + CleanupSQLite() + runner.Close() return nil, err } @@ -190,7 +192,6 @@ func (r *Runner) createState(index int) (*State, error) { L: L, sandbox: sb, index: index, - inUse: false, }, nil } @@ -215,29 +216,26 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context, // Got a state case <-ctx.Done(): return nil, ctx.Err() - case <-time.After(5 * time.Second): + case <-time.After(1 * time.Second): return nil, ErrTimeout } - // Get the actual state state := r.states[stateIndex] if state == nil { r.statePool <- stateIndex return nil, ErrStateNotReady } - // Mark state as in use - state.inUse = true + // Use atomic operations + state.inUse.Store(true) - // Ensure state is returned to pool when done defer func() { - state.inUse = false + state.inUse.Store(false) if r.isRunning.Load() { select { case r.statePool <- stateIndex: - // State returned to pool default: - // Pool is full or closed + // Pool is full or closed, state will be cleaned up by Close() } } }() @@ -267,21 +265,45 @@ func (r *Runner) Close() error { r.isRunning.Store(false) - // Drain the state pool + // Drain all states from the pool for { select { case <-r.statePool: - // Drain one state default: - // Pool is empty - goto cleanup + goto waitForInUse } } -cleanup: - // Clean up all states +waitForInUse: + // Wait for in-use states to finish (with timeout) + timeout := time.Now().Add(10 * time.Second) + for { + allIdle := true + for _, state := range r.states { + if state != nil && state.inUse.Load() { + allIdle = false + break + } + } + + if allIdle { + break + } + + if time.Now().After(timeout) { + logger.Warning("Timeout waiting for states to finish during shutdown, forcing close") + break + } + + time.Sleep(10 * time.Millisecond) + } + + // Now safely close all states for i, state := range r.states { if state != nil { + if state.inUse.Load() { + logger.Warning("Force closing state %d that is still in use", i) + } state.L.Cleanup() state.L.Close() r.states[i] = nil @@ -310,19 +332,40 @@ func (r *Runner) RefreshStates() error { for { select { case <-r.statePool: - // Drain one state default: - // Pool is empty - goto cleanup + goto waitForInUse } } -cleanup: - // Destroy all existing states +waitForInUse: + // Wait for in-use states to finish (with timeout) + timeout := time.Now().Add(10 * time.Second) + for { + allIdle := true + for _, state := range r.states { + if state != nil && state.inUse.Load() { + allIdle = false + break + } + } + + if allIdle { + break + } + + if time.Now().After(timeout) { + logger.Warning("Timeout waiting for states to finish, forcing refresh") + break + } + + time.Sleep(10 * time.Millisecond) + } + + // Now safely destroy all states for i, state := range r.states { if state != nil { - if state.inUse { - logger.Warning("Attempting to refresh state %d that is in use", i) + if state.inUse.Load() { + logger.Warning("Force closing state %d that is still in use", i) } state.L.Cleanup() state.L.Close() @@ -367,7 +410,7 @@ func (r *Runner) RefreshModule(moduleName string) bool { success := true for _, state := range r.states { - if state == nil || state.inUse { + if state == nil || state.inUse.Load() { continue } @@ -403,7 +446,7 @@ func (r *Runner) GetActiveStateCount() int { count := 0 for _, state := range r.states { - if state != nil && state.inUse { + if state != nil && state.inUse.Load() { count++ } } @@ -459,10 +502,10 @@ func (r *Runner) RunScriptFile(filePath string) (*Response, error) { return nil, ErrStateNotReady } - state.inUse = true + state.inUse.Store(true) defer func() { - state.inUse = false + state.inUse.Store(false) if r.isRunning.Load() { select { case r.statePool <- stateIndex: diff --git a/runner/sqlite.go b/runner/sqlite.go index 44e3bb5..c6cb82b 100644 --- a/runner/sqlite.go +++ b/runner/sqlite.go @@ -2,8 +2,6 @@ package runner import ( "context" - "crypto/rand" - "encoding/base64" "fmt" "path/filepath" "strings" @@ -19,71 +17,29 @@ import ( luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) -// DbPools maintains database connection pools var ( - dbPools = make(map[string]*sqlitex.Pool) - poolsMu sync.RWMutex - dataDir string - - // Connection tracking - activeConns = make(map[string]*TrackedConn) - activeConnMu sync.RWMutex - connTimeout = 5 * time.Minute + dbPools = make(map[string]*sqlitex.Pool) + poolsMu sync.RWMutex + dataDir string + poolSize = 8 // Default, will be set to match runner pool size + connTimeout = 5 * time.Second ) -// TrackedConn holds a connection with usage tracking -type TrackedConn struct { - Conn *sqlite.Conn - Pool *sqlitex.Pool - DBName string - LastUsed time.Time -} - -// generateConnToken creates a unique token for connection tracking -func generateConnToken() string { - b := make([]byte, 8) - rand.Read(b) - return base64.URLEncoding.EncodeToString(b) -} - // InitSQLite initializes the SQLite subsystem func InitSQLite(dir string) { dataDir = dir logger.Info("SQLite is g2g! %s", color.Apply(dir, color.Yellow)) - - // Start connection cleanup goroutine - go cleanupIdleConnections() } -// cleanupIdleConnections periodically checks for and removes idle connections -func cleanupIdleConnections() { - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for range ticker.C { - now := time.Now() - - activeConnMu.Lock() - for token, conn := range activeConns { - if conn.LastUsed.Add(connTimeout).Before(now) { - logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName) - conn.Pool.Put(conn.Conn) - delete(activeConns, token) - } - } - activeConnMu.Unlock() +// SetSQLitePoolSize sets the pool size to match the runner pool size +func SetSQLitePoolSize(size int) { + if size > 0 { + poolSize = size } } // CleanupSQLite closes all database connections func CleanupSQLite() { - activeConnMu.Lock() - for token, conn := range activeConns { - conn.Pool.Put(conn.Conn) - delete(activeConns, token) - } - activeConnMu.Unlock() - poolsMu.Lock() defer poolsMu.Unlock() @@ -123,74 +79,36 @@ func getPool(dbName string) (*sqlitex.Pool, error) { return pool, nil } - // Create new pool + // Create new pool with proper size dbPath := filepath.Join(dataDir, dbName+".db") - pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) + pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{ + PoolSize: poolSize, + PrepareConn: func(conn *sqlite.Conn) error { + // Execute PRAGMA statements individually + pragmas := []string{ + "PRAGMA journal_mode = WAL", + "PRAGMA synchronous = NORMAL", + "PRAGMA cache_size = 1000", + "PRAGMA foreign_keys = ON", + "PRAGMA temp_store = MEMORY", + } + for _, pragma := range pragmas { + if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil { + return err + } + } + return nil + }, + }) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } dbPools[dbName] = pool + logger.Debug("Created SQLite pool for %s (size: %d)", dbName, poolSize) return pool, nil } -// getConnection retrieves or creates a tracked connection -func getConnection(token, dbName string) (*TrackedConn, string, error) { - // If token is provided, try to get existing connection - if token != "" { - activeConnMu.RLock() - conn, exists := activeConns[token] - activeConnMu.RUnlock() - - if exists { - conn.LastUsed = time.Now() - return conn, token, nil - } - } - - // Token not provided or connection not found, create new - pool, err := getPool(dbName) - if err != nil { - return nil, "", err - } - - conn, err := pool.Take(context.Background()) - if err != nil { - return nil, "", err - } - - // Generate new token - newToken := generateConnToken() - - trackedConn := &TrackedConn{ - Conn: conn, - Pool: pool, - DBName: dbName, - LastUsed: time.Now(), - } - - activeConnMu.Lock() - activeConns[newToken] = trackedConn - activeConnMu.Unlock() - - return trackedConn, newToken, nil -} - -// releaseConnection releases a connection back to the pool -func releaseConnection(token string) bool { - activeConnMu.Lock() - defer activeConnMu.Unlock() - - conn, exists := activeConns[token] - if !exists { - return false - } - - conn.Pool.Put(conn.Conn) - delete(activeConns, token) - return true -} - // sqlQuery executes a SQL query and returns results func sqlQuery(state *luajit.State) int { // Get required parameters @@ -202,20 +120,23 @@ func sqlQuery(state *luajit.State) int { dbName := state.ToString(1) query := state.ToString(2) - // Get connection token (optional) - var connToken string - if state.GetTop() >= 4 && state.IsString(4) { - connToken = state.ToString(4) - } - - // Get connection - trackedConn, newToken, err := getConnection(connToken, dbName) + // Get pool + pool, err := getPool(dbName) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } - conn := trackedConn.Conn + // Get connection with timeout + ctx, cancel := context.WithTimeout(context.Background(), connTimeout) + defer cancel() + + conn, err := pool.Take(ctx) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error())) + return -1 + } + defer pool.Put(conn) // Create execution options var execOpts sqlitex.ExecOptions @@ -223,64 +144,9 @@ func sqlQuery(state *luajit.State) int { // Set up parameters if provided if state.GetTop() >= 3 && !state.IsNil(3) { - if state.IsTable(3) { - params, err := state.ToTable(3) - if err != nil { - state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error())) - return -1 - } - - // Check for array-style params - if arr, ok := params[""]; ok { - if arrParams, ok := arr.([]any); ok { - execOpts.Args = arrParams - } else if floatArr, ok := arr.([]float64); ok { - args := make([]any, len(floatArr)) - for i, v := range floatArr { - args[i] = v - } - execOpts.Args = args - } - } else { - // Named parameters - named := make(map[string]any, len(params)) - for k, v := range params { - if len(k) > 0 && k[0] != ':' { - named[":"+k] = v - } else { - named[k] = v - } - } - execOpts.Named = named - } - } else { - // Positional parameters - count := state.GetTop() - 2 - if state.IsString(4) { - count-- // Don't include connection token - } - args := make([]any, count) - for i := range count { - idx := i + 3 - switch state.GetType(idx) { - case luajit.TypeNumber: - args[i] = state.ToNumber(idx) - case luajit.TypeString: - args[i] = state.ToString(idx) - case luajit.TypeBoolean: - args[i] = state.ToBoolean(idx) - case luajit.TypeNil: - args[i] = nil - default: - val, err := state.ToValue(idx) - if err != nil { - state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error())) - return -1 - } - args[i] = val - } - } - execOpts.Args = args + if err := setupParams(state, 3, &execOpts); err != nil { + state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) + return -1 } } @@ -300,8 +166,12 @@ func sqlQuery(state *luajit.State) int { row[colName] = stmt.ColumnText(i) case sqlite.TypeBlob: blobSize := stmt.ColumnLen(i) - buf := make([]byte, blobSize) - row[colName] = stmt.ColumnBytes(i, buf) + if blobSize > 0 { + buf := make([]byte, blobSize) + row[colName] = stmt.ColumnBytes(i, buf) + } else { + row[colName] = []byte{} + } case sqlite.TypeNull: row[colName] = nil } @@ -327,10 +197,7 @@ func sqlQuery(state *luajit.State) int { state.SetTable(-3) } - // Return connection token - state.PushString(newToken) - - return 2 + return 1 } // sqlExec executes a SQL statement without returning results @@ -344,56 +211,71 @@ func sqlExec(state *luajit.State) int { dbName := state.ToString(1) query := state.ToString(2) - // Get connection token (optional) - var connToken string - if state.GetTop() >= 4 && state.IsString(4) { - connToken = state.ToString(4) - } - - // Get connection - trackedConn, newToken, err := getConnection(connToken, dbName) + // Get pool + pool, err := getPool(dbName) if err != nil { - state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) + state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } - conn := trackedConn.Conn + // Get connection with timeout + ctx, cancel := context.WithTimeout(context.Background(), connTimeout) + defer cancel() + + conn, err := pool.Take(ctx) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error())) + return -1 + } + defer pool.Put(conn) // Check if parameters are provided hasParams := state.GetTop() >= 3 && !state.IsNil(3) - // Fast path for multi-statement scripts - use ExecScript + // Fast path for multi-statement scripts if strings.Contains(query, ";") && !hasParams { if err := sqlitex.ExecScript(conn, query); err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } state.PushNumber(float64(conn.Changes())) - state.PushString(newToken) - return 2 + return 1 } // Fast path for simple queries with no parameters if !hasParams { - // Use Execute for simple statements without parameters if err := sqlitex.Execute(conn, query, nil); err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } state.PushNumber(float64(conn.Changes())) - state.PushString(newToken) - return 2 + return 1 } // Create execution options for parameterized query var execOpts sqlitex.ExecOptions + if err := setupParams(state, 3, &execOpts); err != nil { + state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) + return -1 + } - // Set up parameters - if state.IsTable(3) { - params, err := state.ToTable(3) + // Execute with parameters + if err := sqlitex.Execute(conn, query, &execOpts); err != nil { + state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) + return -1 + } + + // Return affected rows + state.PushNumber(float64(conn.Changes())) + return 1 +} + +// setupParams configures execution options with parameters from Lua +func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error { + if state.IsTable(paramIndex) { + params, err := state.ToTable(paramIndex) if err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error())) - return -1 + return fmt.Errorf("invalid parameters: %w", err) } // Check for array-style params @@ -420,59 +302,111 @@ func sqlExec(state *luajit.State) int { execOpts.Named = named } } else { - // Positional parameters + // Positional parameters from stack count := state.GetTop() - 2 - if state.IsString(4) { - count-- // Don't include connection token - } args := make([]any, count) for i := range count { idx := i + 3 - switch state.GetType(idx) { - case luajit.TypeNumber: - args[i] = state.ToNumber(idx) - case luajit.TypeString: - args[i] = state.ToString(idx) - case luajit.TypeBoolean: - args[i] = state.ToBoolean(idx) - case luajit.TypeNil: - args[i] = nil - default: - val, err := state.ToValue(idx) - if err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error())) - return -1 - } - args[i] = val + val, err := state.ToValue(idx) + if err != nil { + return fmt.Errorf("invalid parameter %d: %w", i+1, err) } + args[i] = val } execOpts.Args = args } - // Execute with parameters - if err := sqlitex.Execute(conn, query, &execOpts); err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) - return -1 - } - - // Return affected rows and connection token - state.PushNumber(float64(conn.Changes())) - state.PushString(newToken) - return 2 + return nil } -// sqlClose releases a connection back to the pool -func sqlClose(state *luajit.State) int { - if state.GetTop() < 1 || !state.IsString(1) { - state.PushString("sqlite.close: requires connection token") +// sqlGetOne executes a query and returns only the first row +func sqlGetOne(state *luajit.State) int { + // Get required parameters + if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { + state.PushString("sqlite.get_one: requires database name and query") return -1 } - token := state.ToString(1) - if releaseConnection(token) { - state.PushBoolean(true) + dbName := state.ToString(1) + query := state.ToString(2) + + // Get pool + pool, err := getPool(dbName) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) + return -1 + } + + // Get connection with timeout + ctx, cancel := context.WithTimeout(context.Background(), connTimeout) + defer cancel() + + conn, err := pool.Take(ctx) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error())) + return -1 + } + defer pool.Put(conn) + + // Create execution options + var execOpts sqlitex.ExecOptions + var result map[string]any + + // Set up parameters if provided + if state.GetTop() >= 3 && !state.IsNil(3) { + if err := setupParams(state, 3, &execOpts); err != nil { + state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) + return -1 + } + } + + // Set up result function to get only first row + execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { + if result != nil { + return nil // Already got first row + } + + result = make(map[string]any) + colCount := stmt.ColumnCount() + + for i := range colCount { + colName := stmt.ColumnName(i) + switch stmt.ColumnType(i) { + case sqlite.TypeInteger: + result[colName] = stmt.ColumnInt64(i) + case sqlite.TypeFloat: + result[colName] = stmt.ColumnFloat(i) + case sqlite.TypeText: + result[colName] = stmt.ColumnText(i) + case sqlite.TypeBlob: + blobSize := stmt.ColumnLen(i) + if blobSize > 0 { + buf := make([]byte, blobSize) + result[colName] = stmt.ColumnBytes(i, buf) + } else { + result[colName] = []byte{} + } + case sqlite.TypeNull: + result[colName] = nil + } + } + return nil + } + + // Execute query + if err := sqlitex.Execute(conn, query, &execOpts); err != nil { + state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) + return -1 + } + + // Return result or nil if no rows + if result == nil { + state.PushNil() } else { - state.PushBoolean(false) + if err := state.PushTable(result); err != nil { + state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error())) + return -1 + } } return 1 @@ -486,7 +420,7 @@ func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil { return err } - if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil { + if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil { return err } return nil