diff --git a/runner/runner.go b/runner/runner.go index 87999c4..80611c6 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -13,6 +13,7 @@ import ( "Moonshark/router" "Moonshark/runner/lualibs" + "Moonshark/runner/sqlite" "Moonshark/sessions" "Moonshark/utils/logger" @@ -72,15 +73,15 @@ func NewRunner(poolSize int, dataDir, fsDir string, libDirs []string) (*Runner, }, } - lualibs.InitSQLite(dataDir) + sqlite.InitSQLite(dataDir) lualibs.InitFS(fsDir) - lualibs.SetSQLitePoolSize(poolSize) + sqlite.SetSQLitePoolSize(poolSize) r.states = make([]*State, poolSize) r.statePool = make(chan int, poolSize) if err := r.initStates(); err != nil { - lualibs.CleanupSQLite() + sqlite.CleanupSQLite() return nil, err } @@ -314,7 +315,7 @@ cleanup: } lualibs.CleanupFS() - lualibs.CleanupSQLite() + sqlite.CleanupSQLite() return nil } diff --git a/runner/sandbox.go b/runner/sandbox.go index ff33d8c..5dcd76e 100644 --- a/runner/sandbox.go +++ b/runner/sandbox.go @@ -2,6 +2,7 @@ package runner import ( "Moonshark/runner/lualibs" + "Moonshark/runner/sqlite" "fmt" "maps" @@ -201,7 +202,7 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { return err } - if err := lualibs.RegisterSQLiteFunctions(state); err != nil { + if err := sqlite.RegisterSQLiteFunctions(state); err != nil { return err } diff --git a/runner/lualibs/sqlite.go b/runner/sqlite/sqlite.go similarity index 85% rename from runner/lualibs/sqlite.go rename to runner/sqlite/sqlite.go index 00c80f0..e2887b7 100644 --- a/runner/lualibs/sqlite.go +++ b/runner/sqlite/sqlite.go @@ -1,4 +1,4 @@ -package lualibs +package sqlite import ( "context" @@ -18,11 +18,13 @@ import ( ) var ( - dbPools = make(map[string]*sqlitex.Pool) - poolsMu sync.RWMutex - dataDir string - poolSize = 8 - connTimeout = 5 * time.Second + dbPools = make(map[string]*sqlitex.Pool) + poolsMu sync.RWMutex + dataDir string + poolSize = 8 + connTimeout = 5 * time.Second + stateConnections = make(map[*luajit.State]map[string]*sqlite.Conn) + stateConnsMu sync.RWMutex ) func InitSQLite(dir string) { @@ -114,19 +116,10 @@ func sqlQuery(state *luajit.State) int { return state.PushError("sqlite.query: query must be string") } - pool, err := getPool(dbName) - if err != nil { - return state.PushError("sqlite.query: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), connTimeout) - defer cancel() - - conn, err := pool.Take(ctx) + conn, err := getStateConnection(state, dbName) if err != nil { return state.PushError("sqlite.query: connection timeout: %v", err) } - defer pool.Put(conn) var execOpts sqlitex.ExecOptions rows := make([]any, 0, 16) @@ -192,20 +185,11 @@ func sqlExec(state *luajit.State) int { return state.PushError("sqlite.exec: query must be string") } - pool, err := getPool(dbName) + conn, err := getStateConnection(state, dbName) if err != nil { - return state.PushError("sqlite.exec: %v", err) + return state.PushError("sqlite.query: connection timeout: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), connTimeout) - defer cancel() - - conn, err := pool.Take(ctx) - if err != nil { - return state.PushError("sqlite.exec: connection timeout: %v", err) - } - defer pool.Put(conn) - hasParams := state.GetTop() >= 3 && !state.IsNil(3) if strings.Contains(query, ";") && !hasParams { @@ -327,20 +311,11 @@ func sqlGetOne(state *luajit.State) int { return state.PushError("sqlite.get_one: query must be string") } - pool, err := getPool(dbName) + conn, err := getStateConnection(state, dbName) if err != nil { - return state.PushError("sqlite.get_one: %v", err) + return state.PushError("sqlite.query: connection timeout: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), connTimeout) - defer cancel() - - conn, err := pool.Take(ctx) - if err != nil { - return state.PushError("sqlite.get_one: connection timeout: %v", err) - } - defer pool.Put(conn) - var execOpts sqlitex.ExecOptions var result map[string]any @@ -409,3 +384,52 @@ func RegisterSQLiteFunctions(state *luajit.State) error { } return nil } + +func CleanupStateConnections(luaState *luajit.State) { + stateConnsMu.Lock() + defer stateConnsMu.Unlock() + + if conns, exists := stateConnections[luaState]; exists { + for dbName, conn := range conns { + if pool, err := getPool(dbName); err == nil { + pool.Put(conn) + } + } + delete(stateConnections, luaState) + } +} + +func getStateConnection(state *luajit.State, dbName string) (*sqlite.Conn, error) { + stateConnsMu.RLock() + if conns, exists := stateConnections[state]; exists { + if conn, exists := conns[dbName]; exists { + stateConnsMu.RUnlock() + return conn, nil + } + } + stateConnsMu.RUnlock() + + // Get new connection + pool, err := getPool(dbName) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), connTimeout) + defer cancel() + + conn, err := pool.Take(ctx) + if err != nil { + return nil, err + } + + // Cache it + stateConnsMu.Lock() + if stateConnections[state] == nil { + stateConnections[state] = make(map[string]*sqlite.Conn) + } + stateConnections[state][dbName] = conn + stateConnsMu.Unlock() + + return conn, nil +}