move sqlite in runner to its own package

This commit is contained in:
Sky Johnson 2025-06-05 12:58:05 -05:00
parent cf38b947e1
commit 14fcd7894b
3 changed files with 69 additions and 43 deletions

View File

@ -13,6 +13,7 @@ import (
"Moonshark/router" "Moonshark/router"
"Moonshark/runner/lualibs" "Moonshark/runner/lualibs"
"Moonshark/runner/sqlite"
"Moonshark/sessions" "Moonshark/sessions"
"Moonshark/utils/logger" "Moonshark/utils/logger"
@ -72,15 +73,15 @@ func NewRunner(poolSize int, dataDir, fsDir string, libDirs []string) (*Runner,
}, },
} }
lualibs.InitSQLite(dataDir) sqlite.InitSQLite(dataDir)
lualibs.InitFS(fsDir) lualibs.InitFS(fsDir)
lualibs.SetSQLitePoolSize(poolSize) sqlite.SetSQLitePoolSize(poolSize)
r.states = make([]*State, poolSize) r.states = make([]*State, poolSize)
r.statePool = make(chan int, poolSize) r.statePool = make(chan int, poolSize)
if err := r.initStates(); err != nil { if err := r.initStates(); err != nil {
lualibs.CleanupSQLite() sqlite.CleanupSQLite()
return nil, err return nil, err
} }
@ -314,7 +315,7 @@ cleanup:
} }
lualibs.CleanupFS() lualibs.CleanupFS()
lualibs.CleanupSQLite() sqlite.CleanupSQLite()
return nil return nil
} }

View File

@ -2,6 +2,7 @@ package runner
import ( import (
"Moonshark/runner/lualibs" "Moonshark/runner/lualibs"
"Moonshark/runner/sqlite"
"fmt" "fmt"
"maps" "maps"
@ -201,7 +202,7 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
return err return err
} }
if err := lualibs.RegisterSQLiteFunctions(state); err != nil { if err := sqlite.RegisterSQLiteFunctions(state); err != nil {
return err return err
} }

View File

@ -1,4 +1,4 @@
package lualibs package sqlite
import ( import (
"context" "context"
@ -18,11 +18,13 @@ import (
) )
var ( var (
dbPools = make(map[string]*sqlitex.Pool) dbPools = make(map[string]*sqlitex.Pool)
poolsMu sync.RWMutex poolsMu sync.RWMutex
dataDir string dataDir string
poolSize = 8 poolSize = 8
connTimeout = 5 * time.Second connTimeout = 5 * time.Second
stateConnections = make(map[*luajit.State]map[string]*sqlite.Conn)
stateConnsMu sync.RWMutex
) )
func InitSQLite(dir string) { func InitSQLite(dir string) {
@ -114,19 +116,10 @@ func sqlQuery(state *luajit.State) int {
return state.PushError("sqlite.query: query must be string") return state.PushError("sqlite.query: query must be string")
} }
pool, err := getPool(dbName) conn, err := getStateConnection(state, dbName)
if err != nil {
return state.PushError("sqlite.query: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil { if err != nil {
return state.PushError("sqlite.query: connection timeout: %v", err) return state.PushError("sqlite.query: connection timeout: %v", err)
} }
defer pool.Put(conn)
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
rows := make([]any, 0, 16) rows := make([]any, 0, 16)
@ -192,20 +185,11 @@ func sqlExec(state *luajit.State) int {
return state.PushError("sqlite.exec: query must be string") return state.PushError("sqlite.exec: query must be string")
} }
pool, err := getPool(dbName) conn, err := getStateConnection(state, dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.exec: %v", err) return state.PushError("sqlite.query: connection timeout: %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return state.PushError("sqlite.exec: connection timeout: %v", err)
}
defer pool.Put(conn)
hasParams := state.GetTop() >= 3 && !state.IsNil(3) hasParams := state.GetTop() >= 3 && !state.IsNil(3)
if strings.Contains(query, ";") && !hasParams { if strings.Contains(query, ";") && !hasParams {
@ -327,20 +311,11 @@ func sqlGetOne(state *luajit.State) int {
return state.PushError("sqlite.get_one: query must be string") return state.PushError("sqlite.get_one: query must be string")
} }
pool, err := getPool(dbName) conn, err := getStateConnection(state, dbName)
if err != nil { if err != nil {
return state.PushError("sqlite.get_one: %v", err) return state.PushError("sqlite.query: connection timeout: %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return state.PushError("sqlite.get_one: connection timeout: %v", err)
}
defer pool.Put(conn)
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
var result map[string]any var result map[string]any
@ -409,3 +384,52 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
} }
return nil return nil
} }
func CleanupStateConnections(luaState *luajit.State) {
stateConnsMu.Lock()
defer stateConnsMu.Unlock()
if conns, exists := stateConnections[luaState]; exists {
for dbName, conn := range conns {
if pool, err := getPool(dbName); err == nil {
pool.Put(conn)
}
}
delete(stateConnections, luaState)
}
}
func getStateConnection(state *luajit.State, dbName string) (*sqlite.Conn, error) {
stateConnsMu.RLock()
if conns, exists := stateConnections[state]; exists {
if conn, exists := conns[dbName]; exists {
stateConnsMu.RUnlock()
return conn, nil
}
}
stateConnsMu.RUnlock()
// Get new connection
pool, err := getPool(dbName)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
return nil, err
}
// Cache it
stateConnsMu.Lock()
if stateConnections[state] == nil {
stateConnections[state] = make(map[string]*sqlite.Conn)
}
stateConnections[state][dbName] = conn
stateConnsMu.Unlock()
return conn, nil
}