move sqlite in runner to its own package
This commit is contained in:
parent
cf38b947e1
commit
14fcd7894b
@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"Moonshark/router"
|
"Moonshark/router"
|
||||||
"Moonshark/runner/lualibs"
|
"Moonshark/runner/lualibs"
|
||||||
|
"Moonshark/runner/sqlite"
|
||||||
"Moonshark/sessions"
|
"Moonshark/sessions"
|
||||||
"Moonshark/utils/logger"
|
"Moonshark/utils/logger"
|
||||||
|
|
||||||
@ -72,15 +73,15 @@ func NewRunner(poolSize int, dataDir, fsDir string, libDirs []string) (*Runner,
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
lualibs.InitSQLite(dataDir)
|
sqlite.InitSQLite(dataDir)
|
||||||
lualibs.InitFS(fsDir)
|
lualibs.InitFS(fsDir)
|
||||||
lualibs.SetSQLitePoolSize(poolSize)
|
sqlite.SetSQLitePoolSize(poolSize)
|
||||||
|
|
||||||
r.states = make([]*State, poolSize)
|
r.states = make([]*State, poolSize)
|
||||||
r.statePool = make(chan int, poolSize)
|
r.statePool = make(chan int, poolSize)
|
||||||
|
|
||||||
if err := r.initStates(); err != nil {
|
if err := r.initStates(); err != nil {
|
||||||
lualibs.CleanupSQLite()
|
sqlite.CleanupSQLite()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -314,7 +315,7 @@ cleanup:
|
|||||||
}
|
}
|
||||||
|
|
||||||
lualibs.CleanupFS()
|
lualibs.CleanupFS()
|
||||||
lualibs.CleanupSQLite()
|
sqlite.CleanupSQLite()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package runner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"Moonshark/runner/lualibs"
|
"Moonshark/runner/lualibs"
|
||||||
|
"Moonshark/runner/sqlite"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"maps"
|
"maps"
|
||||||
@ -201,7 +202,7 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := lualibs.RegisterSQLiteFunctions(state); err != nil {
|
if err := sqlite.RegisterSQLiteFunctions(state); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package lualibs
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -23,6 +23,8 @@ var (
|
|||||||
dataDir string
|
dataDir string
|
||||||
poolSize = 8
|
poolSize = 8
|
||||||
connTimeout = 5 * time.Second
|
connTimeout = 5 * time.Second
|
||||||
|
stateConnections = make(map[*luajit.State]map[string]*sqlite.Conn)
|
||||||
|
stateConnsMu sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
func InitSQLite(dir string) {
|
func InitSQLite(dir string) {
|
||||||
@ -114,19 +116,10 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
return state.PushError("sqlite.query: query must be string")
|
return state.PushError("sqlite.query: query must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
pool, err := getPool(dbName)
|
conn, err := getStateConnection(state, dbName)
|
||||||
if err != nil {
|
|
||||||
return state.PushError("sqlite.query: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pool.Take(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.PushError("sqlite.query: connection timeout: %v", err)
|
return state.PushError("sqlite.query: connection timeout: %v", err)
|
||||||
}
|
}
|
||||||
defer pool.Put(conn)
|
|
||||||
|
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
rows := make([]any, 0, 16)
|
rows := make([]any, 0, 16)
|
||||||
@ -192,20 +185,11 @@ func sqlExec(state *luajit.State) int {
|
|||||||
return state.PushError("sqlite.exec: query must be string")
|
return state.PushError("sqlite.exec: query must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
pool, err := getPool(dbName)
|
conn, err := getStateConnection(state, dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.PushError("sqlite.exec: %v", err)
|
return state.PushError("sqlite.query: connection timeout: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pool.Take(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return state.PushError("sqlite.exec: connection timeout: %v", err)
|
|
||||||
}
|
|
||||||
defer pool.Put(conn)
|
|
||||||
|
|
||||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||||
|
|
||||||
if strings.Contains(query, ";") && !hasParams {
|
if strings.Contains(query, ";") && !hasParams {
|
||||||
@ -327,20 +311,11 @@ func sqlGetOne(state *luajit.State) int {
|
|||||||
return state.PushError("sqlite.get_one: query must be string")
|
return state.PushError("sqlite.get_one: query must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
pool, err := getPool(dbName)
|
conn, err := getStateConnection(state, dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.PushError("sqlite.get_one: %v", err)
|
return state.PushError("sqlite.query: connection timeout: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pool.Take(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return state.PushError("sqlite.get_one: connection timeout: %v", err)
|
|
||||||
}
|
|
||||||
defer pool.Put(conn)
|
|
||||||
|
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
|
|
||||||
@ -409,3 +384,52 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CleanupStateConnections(luaState *luajit.State) {
|
||||||
|
stateConnsMu.Lock()
|
||||||
|
defer stateConnsMu.Unlock()
|
||||||
|
|
||||||
|
if conns, exists := stateConnections[luaState]; exists {
|
||||||
|
for dbName, conn := range conns {
|
||||||
|
if pool, err := getPool(dbName); err == nil {
|
||||||
|
pool.Put(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(stateConnections, luaState)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStateConnection(state *luajit.State, dbName string) (*sqlite.Conn, error) {
|
||||||
|
stateConnsMu.RLock()
|
||||||
|
if conns, exists := stateConnections[state]; exists {
|
||||||
|
if conn, exists := conns[dbName]; exists {
|
||||||
|
stateConnsMu.RUnlock()
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stateConnsMu.RUnlock()
|
||||||
|
|
||||||
|
// Get new connection
|
||||||
|
pool, err := getPool(dbName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pool.Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache it
|
||||||
|
stateConnsMu.Lock()
|
||||||
|
if stateConnections[state] == nil {
|
||||||
|
stateConnections[state] = make(map[string]*sqlite.Conn)
|
||||||
|
}
|
||||||
|
stateConnections[state][dbName] = conn
|
||||||
|
stateConnsMu.Unlock()
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user