diff --git a/runner/embed.go b/runner/embed.go index 44c509b..ea27ed9 100644 --- a/runner/embed.go +++ b/runner/embed.go @@ -67,20 +67,20 @@ type Module struct { var modules = []Module{ {"http", httpLuaCode, true}, + {"string", stringLuaCode, false}, + {"table", tableLuaCode, false}, + {"util", utilLuaCode, true}, {"cookie", cookieLuaCode, true}, {"session", sessionLuaCode, true}, {"csrf", csrfLuaCode, true}, {"render", renderLuaCode, true}, {"json", jsonLuaCode, true}, - {"sqlite", sqliteLuaCode, false}, {"fs", fsLuaCode, true}, - {"util", utilLuaCode, true}, - {"string", stringLuaCode, false}, - {"table", tableLuaCode, false}, {"crypto", cryptoLuaCode, true}, {"time", timeLuaCode, false}, {"math", mathLuaCode, false}, {"env", envLuaCode, true}, + {"sqlite", sqliteLuaCode, true}, {"timestamp", timestampLuaCode, false}, } diff --git a/runner/lua/sqlite.lua b/runner/lua/sqlite.lua index e859109..5d4096e 100644 --- a/runner/lua/sqlite.lua +++ b/runner/lua/sqlite.lua @@ -288,7 +288,7 @@ local connection_mt = { } } -return function(db_name) +function sqlite(db_name) if type(db_name) ~= "string" then error("sqlite: database name must be a string", 2) end diff --git a/runner/runner.go b/runner/runner.go index 80611c6..a7de247 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -249,7 +249,7 @@ func (r *Runner) createState(index int) (*State, error) { } sb := NewSandbox() - if err := sb.Setup(L, index == 0); err != nil { + if err := sb.Setup(L, index, index == 0); err != nil { L.Cleanup() L.Close() return nil, err diff --git a/runner/sandbox.go b/runner/sandbox.go index 5dcd76e..8c64663 100644 --- a/runner/sandbox.go +++ b/runner/sandbox.go @@ -22,12 +22,16 @@ func NewSandbox() *Sandbox { } // Setup initializes the sandbox in a Lua state -func (s *Sandbox) Setup(state *luajit.State, verbose bool) error { +func (s *Sandbox) Setup(state *luajit.State, stateIndex int, verbose bool) error { // Load all embedded modules and sandbox if err := loadSandboxIntoState(state, verbose); err != nil { return fmt.Errorf("failed to load sandbox: %w", err) } + // Set the state index as a global variable + state.PushNumber(float64(stateIndex)) + state.SetGlobal("__STATE_INDEX") + // Pre-compile the executor function for reuse executorCode := `return __execute` bytecode, err := state.CompileBytecode(executorCode, "executor") diff --git a/runner/sqlite/sqlite.go b/runner/sqlite/sqlite.go index e2887b7..d430b8a 100644 --- a/runner/sqlite/sqlite.go +++ b/runner/sqlite/sqlite.go @@ -18,13 +18,11 @@ import ( ) var ( - dbPools = make(map[string]*sqlitex.Pool) - poolsMu sync.RWMutex - dataDir string - poolSize = 8 - connTimeout = 5 * time.Second - stateConnections = make(map[*luajit.State]map[string]*sqlite.Conn) - stateConnsMu sync.RWMutex + dbPools = make(map[string]*sqlitex.Pool) + poolsMu sync.RWMutex + dataDir string + poolSize = 8 + connTimeout = 5 * time.Second ) func InitSQLite(dir string) { @@ -116,10 +114,19 @@ func sqlQuery(state *luajit.State) int { return state.PushError("sqlite.query: query must be string") } - conn, err := getStateConnection(state, dbName) + pool, err := getPool(dbName) + if err != nil { + return state.PushError("sqlite.query: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), connTimeout) + defer cancel() + + conn, err := pool.Take(ctx) if err != nil { return state.PushError("sqlite.query: connection timeout: %v", err) } + defer pool.Put(conn) var execOpts sqlitex.ExecOptions rows := make([]any, 0, 16) @@ -185,11 +192,20 @@ func sqlExec(state *luajit.State) int { return state.PushError("sqlite.exec: query must be string") } - conn, err := getStateConnection(state, dbName) + pool, err := getPool(dbName) if err != nil { - return state.PushError("sqlite.query: connection timeout: %v", err) + return state.PushError("sqlite.exec: %v", err) } + ctx, cancel := context.WithTimeout(context.Background(), connTimeout) + defer cancel() + + conn, err := pool.Take(ctx) + if err != nil { + return state.PushError("sqlite.exec: connection timeout: %v", err) + } + defer pool.Put(conn) + hasParams := state.GetTop() >= 3 && !state.IsNil(3) if strings.Contains(query, ";") && !hasParams { @@ -311,11 +327,20 @@ func sqlGetOne(state *luajit.State) int { return state.PushError("sqlite.get_one: query must be string") } - conn, err := getStateConnection(state, dbName) + pool, err := getPool(dbName) if err != nil { - return state.PushError("sqlite.query: connection timeout: %v", err) + return state.PushError("sqlite.get_one: %v", err) } + ctx, cancel := context.WithTimeout(context.Background(), connTimeout) + defer cancel() + + conn, err := pool.Take(ctx) + if err != nil { + return state.PushError("sqlite.get_one: connection timeout: %v", err) + } + defer pool.Put(conn) + var execOpts sqlitex.ExecOptions var result map[string]any @@ -384,52 +409,3 @@ func RegisterSQLiteFunctions(state *luajit.State) error { } return nil } - -func CleanupStateConnections(luaState *luajit.State) { - stateConnsMu.Lock() - defer stateConnsMu.Unlock() - - if conns, exists := stateConnections[luaState]; exists { - for dbName, conn := range conns { - if pool, err := getPool(dbName); err == nil { - pool.Put(conn) - } - } - delete(stateConnections, luaState) - } -} - -func getStateConnection(state *luajit.State, dbName string) (*sqlite.Conn, error) { - stateConnsMu.RLock() - if conns, exists := stateConnections[state]; exists { - if conn, exists := conns[dbName]; exists { - stateConnsMu.RUnlock() - return conn, nil - } - } - stateConnsMu.RUnlock() - - // Get new connection - pool, err := getPool(dbName) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithTimeout(context.Background(), connTimeout) - defer cancel() - - conn, err := pool.Take(ctx) - if err != nil { - return nil, err - } - - // Cache it - stateConnsMu.Lock() - if stateConnections[state] == nil { - stateConnections[state] = make(map[string]*sqlite.Conn) - } - stateConnections[state][dbName] = conn - stateConnsMu.Unlock() - - return conn, nil -}