revert sqlite state tracking, add state index as global
This commit is contained in:
parent
14fcd7894b
commit
0c4ddd7e3d
@ -67,20 +67,20 @@ type Module struct {
|
|||||||
|
|
||||||
var modules = []Module{
|
var modules = []Module{
|
||||||
{"http", httpLuaCode, true},
|
{"http", httpLuaCode, true},
|
||||||
|
{"string", stringLuaCode, false},
|
||||||
|
{"table", tableLuaCode, false},
|
||||||
|
{"util", utilLuaCode, true},
|
||||||
{"cookie", cookieLuaCode, true},
|
{"cookie", cookieLuaCode, true},
|
||||||
{"session", sessionLuaCode, true},
|
{"session", sessionLuaCode, true},
|
||||||
{"csrf", csrfLuaCode, true},
|
{"csrf", csrfLuaCode, true},
|
||||||
{"render", renderLuaCode, true},
|
{"render", renderLuaCode, true},
|
||||||
{"json", jsonLuaCode, true},
|
{"json", jsonLuaCode, true},
|
||||||
{"sqlite", sqliteLuaCode, false},
|
|
||||||
{"fs", fsLuaCode, true},
|
{"fs", fsLuaCode, true},
|
||||||
{"util", utilLuaCode, true},
|
|
||||||
{"string", stringLuaCode, false},
|
|
||||||
{"table", tableLuaCode, false},
|
|
||||||
{"crypto", cryptoLuaCode, true},
|
{"crypto", cryptoLuaCode, true},
|
||||||
{"time", timeLuaCode, false},
|
{"time", timeLuaCode, false},
|
||||||
{"math", mathLuaCode, false},
|
{"math", mathLuaCode, false},
|
||||||
{"env", envLuaCode, true},
|
{"env", envLuaCode, true},
|
||||||
|
{"sqlite", sqliteLuaCode, true},
|
||||||
{"timestamp", timestampLuaCode, false},
|
{"timestamp", timestampLuaCode, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -288,7 +288,7 @@ local connection_mt = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return function(db_name)
|
function sqlite(db_name)
|
||||||
if type(db_name) ~= "string" then
|
if type(db_name) ~= "string" then
|
||||||
error("sqlite: database name must be a string", 2)
|
error("sqlite: database name must be a string", 2)
|
||||||
end
|
end
|
||||||
|
@ -249,7 +249,7 @@ func (r *Runner) createState(index int) (*State, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sb := NewSandbox()
|
sb := NewSandbox()
|
||||||
if err := sb.Setup(L, index == 0); err != nil {
|
if err := sb.Setup(L, index, index == 0); err != nil {
|
||||||
L.Cleanup()
|
L.Cleanup()
|
||||||
L.Close()
|
L.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -22,12 +22,16 @@ func NewSandbox() *Sandbox {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Setup initializes the sandbox in a Lua state
|
// Setup initializes the sandbox in a Lua state
|
||||||
func (s *Sandbox) Setup(state *luajit.State, verbose bool) error {
|
func (s *Sandbox) Setup(state *luajit.State, stateIndex int, verbose bool) error {
|
||||||
// Load all embedded modules and sandbox
|
// Load all embedded modules and sandbox
|
||||||
if err := loadSandboxIntoState(state, verbose); err != nil {
|
if err := loadSandboxIntoState(state, verbose); err != nil {
|
||||||
return fmt.Errorf("failed to load sandbox: %w", err)
|
return fmt.Errorf("failed to load sandbox: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the state index as a global variable
|
||||||
|
state.PushNumber(float64(stateIndex))
|
||||||
|
state.SetGlobal("__STATE_INDEX")
|
||||||
|
|
||||||
// Pre-compile the executor function for reuse
|
// Pre-compile the executor function for reuse
|
||||||
executorCode := `return __execute`
|
executorCode := `return __execute`
|
||||||
bytecode, err := state.CompileBytecode(executorCode, "executor")
|
bytecode, err := state.CompileBytecode(executorCode, "executor")
|
||||||
|
@ -18,13 +18,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
dbPools = make(map[string]*sqlitex.Pool)
|
dbPools = make(map[string]*sqlitex.Pool)
|
||||||
poolsMu sync.RWMutex
|
poolsMu sync.RWMutex
|
||||||
dataDir string
|
dataDir string
|
||||||
poolSize = 8
|
poolSize = 8
|
||||||
connTimeout = 5 * time.Second
|
connTimeout = 5 * time.Second
|
||||||
stateConnections = make(map[*luajit.State]map[string]*sqlite.Conn)
|
|
||||||
stateConnsMu sync.RWMutex
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func InitSQLite(dir string) {
|
func InitSQLite(dir string) {
|
||||||
@ -116,10 +114,19 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
return state.PushError("sqlite.query: query must be string")
|
return state.PushError("sqlite.query: query must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := getStateConnection(state, dbName)
|
pool, err := getPool(dbName)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("sqlite.query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pool.Take(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.PushError("sqlite.query: connection timeout: %v", err)
|
return state.PushError("sqlite.query: connection timeout: %v", err)
|
||||||
}
|
}
|
||||||
|
defer pool.Put(conn)
|
||||||
|
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
rows := make([]any, 0, 16)
|
rows := make([]any, 0, 16)
|
||||||
@ -185,11 +192,20 @@ func sqlExec(state *luajit.State) int {
|
|||||||
return state.PushError("sqlite.exec: query must be string")
|
return state.PushError("sqlite.exec: query must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := getStateConnection(state, dbName)
|
pool, err := getPool(dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.PushError("sqlite.query: connection timeout: %v", err)
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pool.Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("sqlite.exec: connection timeout: %v", err)
|
||||||
|
}
|
||||||
|
defer pool.Put(conn)
|
||||||
|
|
||||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||||
|
|
||||||
if strings.Contains(query, ";") && !hasParams {
|
if strings.Contains(query, ";") && !hasParams {
|
||||||
@ -311,11 +327,20 @@ func sqlGetOne(state *luajit.State) int {
|
|||||||
return state.PushError("sqlite.get_one: query must be string")
|
return state.PushError("sqlite.get_one: query must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := getStateConnection(state, dbName)
|
pool, err := getPool(dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.PushError("sqlite.query: connection timeout: %v", err)
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pool.Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return state.PushError("sqlite.get_one: connection timeout: %v", err)
|
||||||
|
}
|
||||||
|
defer pool.Put(conn)
|
||||||
|
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
|
|
||||||
@ -384,52 +409,3 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CleanupStateConnections(luaState *luajit.State) {
|
|
||||||
stateConnsMu.Lock()
|
|
||||||
defer stateConnsMu.Unlock()
|
|
||||||
|
|
||||||
if conns, exists := stateConnections[luaState]; exists {
|
|
||||||
for dbName, conn := range conns {
|
|
||||||
if pool, err := getPool(dbName); err == nil {
|
|
||||||
pool.Put(conn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(stateConnections, luaState)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getStateConnection(state *luajit.State, dbName string) (*sqlite.Conn, error) {
|
|
||||||
stateConnsMu.RLock()
|
|
||||||
if conns, exists := stateConnections[state]; exists {
|
|
||||||
if conn, exists := conns[dbName]; exists {
|
|
||||||
stateConnsMu.RUnlock()
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stateConnsMu.RUnlock()
|
|
||||||
|
|
||||||
// Get new connection
|
|
||||||
pool, err := getPool(dbName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pool.Take(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache it
|
|
||||||
stateConnsMu.Lock()
|
|
||||||
if stateConnections[state] == nil {
|
|
||||||
stateConnections[state] = make(map[string]*sqlite.Conn)
|
|
||||||
}
|
|
||||||
stateConnections[state][dbName] = conn
|
|
||||||
stateConnsMu.Unlock()
|
|
||||||
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user