revert sqlite state tracking, add state index as global

This commit is contained in:
Sky Johnson 2025-06-05 14:52:50 -05:00
parent 14fcd7894b
commit 0c4ddd7e3d
5 changed files with 48 additions and 68 deletions

View File

@ -67,20 +67,20 @@ type Module struct {
var modules = []Module{ var modules = []Module{
{"http", httpLuaCode, true}, {"http", httpLuaCode, true},
{"string", stringLuaCode, false},
{"table", tableLuaCode, false},
{"util", utilLuaCode, true},
{"cookie", cookieLuaCode, true}, {"cookie", cookieLuaCode, true},
{"session", sessionLuaCode, true}, {"session", sessionLuaCode, true},
{"csrf", csrfLuaCode, true}, {"csrf", csrfLuaCode, true},
{"render", renderLuaCode, true}, {"render", renderLuaCode, true},
{"json", jsonLuaCode, true}, {"json", jsonLuaCode, true},
{"sqlite", sqliteLuaCode, false},
{"fs", fsLuaCode, true}, {"fs", fsLuaCode, true},
{"util", utilLuaCode, true},
{"string", stringLuaCode, false},
{"table", tableLuaCode, false},
{"crypto", cryptoLuaCode, true}, {"crypto", cryptoLuaCode, true},
{"time", timeLuaCode, false}, {"time", timeLuaCode, false},
{"math", mathLuaCode, false}, {"math", mathLuaCode, false},
{"env", envLuaCode, true}, {"env", envLuaCode, true},
{"sqlite", sqliteLuaCode, true},
{"timestamp", timestampLuaCode, false}, {"timestamp", timestampLuaCode, false},
} }

View File

@ -288,7 +288,7 @@ local connection_mt = {
} }
} }
return function(db_name) function sqlite(db_name)
if type(db_name) ~= "string" then if type(db_name) ~= "string" then
error("sqlite: database name must be a string", 2) error("sqlite: database name must be a string", 2)
end end

View File

@ -249,7 +249,7 @@ func (r *Runner) createState(index int) (*State, error) {
} }
sb := NewSandbox() sb := NewSandbox()
if err := sb.Setup(L, index == 0); err != nil { if err := sb.Setup(L, index, index == 0); err != nil {
L.Cleanup() L.Cleanup()
L.Close() L.Close()
return nil, err return nil, err

View File

@ -22,12 +22,16 @@ func NewSandbox() *Sandbox {
} }
// Setup initializes the sandbox in a Lua state // Setup initializes the sandbox in a Lua state
func (s *Sandbox) Setup(state *luajit.State, verbose bool) error { func (s *Sandbox) Setup(state *luajit.State, stateIndex int, verbose bool) error {
// Load all embedded modules and sandbox // Load all embedded modules and sandbox
if err := loadSandboxIntoState(state, verbose); err != nil { if err := loadSandboxIntoState(state, verbose); err != nil {
return fmt.Errorf("failed to load sandbox: %w", err) return fmt.Errorf("failed to load sandbox: %w", err)
} }
// Set the state index as a global variable
state.PushNumber(float64(stateIndex))
state.SetGlobal("__STATE_INDEX")
// Pre-compile the executor function for reuse // Pre-compile the executor function for reuse
executorCode := `return __execute` executorCode := `return __execute`
bytecode, err := state.CompileBytecode(executorCode, "executor") bytecode, err := state.CompileBytecode(executorCode, "executor")

View File

@ -18,13 +18,11 @@ import (
) )
var ( var (
dbPools = make(map[string]*sqlitex.Pool) dbPools = make(map[string]*sqlitex.Pool)
poolsMu sync.RWMutex poolsMu sync.RWMutex
dataDir string dataDir string
poolSize = 8 poolSize = 8
connTimeout = 5 * time.Second connTimeout = 5 * time.Second
stateConnections = make(map[*luajit.State]map[string]*sqlite.Conn)
stateConnsMu sync.RWMutex
) )
func InitSQLite(dir string) { func InitSQLite(dir string) {
@ -116,10 +114,19 @@ func sqlQuery(state *luajit.State) int {
return state.PushError("sqlite.query: query must be string") return state.PushError("sqlite.query: query must be string")
} }
conn, err := getStateConnection(state, dbName) pool, err := getPool(dbName)
if err != nil {
return state.PushError("sqlite.query: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil { if err != nil {
return state.PushError("sqlite.query: connection timeout: %v", err) return state.PushError("sqlite.query: connection timeout: %v", err)
} }
defer pool.Put(conn)
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
rows := make([]any, 0, 16) rows := make([]any, 0, 16)
@ -185,11 +192,20 @@ func sqlExec(state *luajit.State) int {
return state.PushError("sqlite.exec: query must be string") return state.PushError("sqlite.exec: query must be string")
} }
conn, err := getStateConnection(state, dbName) pool, err := getPool(dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.query: connection timeout: %v", err) return state.PushError("sqlite.exec: %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return state.PushError("sqlite.exec: connection timeout: %v", err)
}
defer pool.Put(conn)
hasParams := state.GetTop() >= 3 && !state.IsNil(3) hasParams := state.GetTop() >= 3 && !state.IsNil(3)
if strings.Contains(query, ";") && !hasParams { if strings.Contains(query, ";") && !hasParams {
@ -311,11 +327,20 @@ func sqlGetOne(state *luajit.State) int {
return state.PushError("sqlite.get_one: query must be string") return state.PushError("sqlite.get_one: query must be string")
} }
conn, err := getStateConnection(state, dbName) pool, err := getPool(dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.query: connection timeout: %v", err) return state.PushError("sqlite.get_one: %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return state.PushError("sqlite.get_one: connection timeout: %v", err)
}
defer pool.Put(conn)
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
var result map[string]any var result map[string]any
@ -384,52 +409,3 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
} }
return nil return nil
} }
func CleanupStateConnections(luaState *luajit.State) {
stateConnsMu.Lock()
defer stateConnsMu.Unlock()
if conns, exists := stateConnections[luaState]; exists {
for dbName, conn := range conns {
if pool, err := getPool(dbName); err == nil {
pool.Put(conn)
}
}
delete(stateConnections, luaState)
}
}
func getStateConnection(state *luajit.State, dbName string) (*sqlite.Conn, error) {
stateConnsMu.RLock()
if conns, exists := stateConnections[state]; exists {
if conn, exists := conns[dbName]; exists {
stateConnsMu.RUnlock()
return conn, nil
}
}
stateConnsMu.RUnlock()
// Get new connection
pool, err := getPool(dbName)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return nil, err
}
// Cache it
stateConnsMu.Lock()
if stateConnections[state] == nil {
stateConnections[state] = make(map[string]*sqlite.Conn)
}
stateConnections[state][dbName] = conn
stateConnsMu.Unlock()
return conn, nil
}