493 lines
12 KiB
Go
493 lines
12 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
sqlite "zombiezen.com/go/sqlite"
|
|
"zombiezen.com/go/sqlite/sqlitex"
|
|
|
|
"Moonshark/utils/logger"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
// DbPools maintains database connection pools
|
|
var (
|
|
dbPools = make(map[string]*sqlitex.Pool)
|
|
poolsMu sync.RWMutex
|
|
dataDir string
|
|
|
|
// Connection tracking
|
|
activeConns = make(map[string]*TrackedConn)
|
|
activeConnMu sync.RWMutex
|
|
connTimeout = 5 * time.Minute
|
|
)
|
|
|
|
// TrackedConn holds a connection with usage tracking
|
|
type TrackedConn struct {
|
|
Conn *sqlite.Conn
|
|
Pool *sqlitex.Pool
|
|
DBName string
|
|
LastUsed time.Time
|
|
}
|
|
|
|
// generateConnToken returns a fresh random token (8 bytes from crypto/rand,
// URL-safe base64 with padding, so 12 characters) used as a key in
// activeConns.
func generateConnToken() string {
	var raw [8]byte
	// The error is deliberately ignored, as before: crypto/rand failures are
	// not expected (since Go 1.24 Read never returns an error).
	_, _ = rand.Read(raw[:])
	return base64.URLEncoding.EncodeToString(raw[:])
}
|
|
|
|
// InitSQLite initializes the SQLite subsystem
|
|
func InitSQLite(dir string) {
|
|
dataDir = dir
|
|
logger.Server("SQLite initialized with data directory: %s", dir)
|
|
|
|
// Start connection cleanup goroutine
|
|
go cleanupIdleConnections()
|
|
}
|
|
|
|
// cleanupIdleConnections periodically checks for and removes idle connections
|
|
func cleanupIdleConnections() {
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
now := time.Now()
|
|
|
|
activeConnMu.Lock()
|
|
for token, conn := range activeConns {
|
|
if conn.LastUsed.Add(connTimeout).Before(now) {
|
|
logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName)
|
|
conn.Pool.Put(conn.Conn)
|
|
delete(activeConns, token)
|
|
}
|
|
}
|
|
activeConnMu.Unlock()
|
|
}
|
|
}
|
|
|
|
// CleanupSQLite closes all database connections
|
|
func CleanupSQLite() {
|
|
activeConnMu.Lock()
|
|
for token, conn := range activeConns {
|
|
conn.Pool.Put(conn.Conn)
|
|
delete(activeConns, token)
|
|
}
|
|
activeConnMu.Unlock()
|
|
|
|
poolsMu.Lock()
|
|
defer poolsMu.Unlock()
|
|
|
|
for name, pool := range dbPools {
|
|
if err := pool.Close(); err != nil {
|
|
logger.Error("Failed to close database %s: %v", name, err)
|
|
}
|
|
}
|
|
|
|
dbPools = make(map[string]*sqlitex.Pool)
|
|
logger.Debug("SQLite connections closed")
|
|
}
|
|
|
|
// getPool returns a connection pool for the database
|
|
func getPool(dbName string) (*sqlitex.Pool, error) {
|
|
// Validate database name
|
|
dbName = filepath.Base(dbName)
|
|
if dbName == "" || dbName[0] == '.' {
|
|
return nil, fmt.Errorf("invalid database name")
|
|
}
|
|
|
|
// Check for existing pool
|
|
poolsMu.RLock()
|
|
pool, exists := dbPools[dbName]
|
|
if exists {
|
|
poolsMu.RUnlock()
|
|
return pool, nil
|
|
}
|
|
poolsMu.RUnlock()
|
|
|
|
// Create new pool under write lock
|
|
poolsMu.Lock()
|
|
defer poolsMu.Unlock()
|
|
|
|
// Double-check if a pool was created while waiting for lock
|
|
if pool, exists = dbPools[dbName]; exists {
|
|
return pool, nil
|
|
}
|
|
|
|
// Create new pool
|
|
dbPath := filepath.Join(dataDir, dbName+".db")
|
|
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
dbPools[dbName] = pool
|
|
return pool, nil
|
|
}
|
|
|
|
// getConnection retrieves or creates a tracked connection
|
|
func getConnection(token, dbName string) (*TrackedConn, string, error) {
|
|
// If token is provided, try to get existing connection
|
|
if token != "" {
|
|
activeConnMu.RLock()
|
|
conn, exists := activeConns[token]
|
|
activeConnMu.RUnlock()
|
|
|
|
if exists {
|
|
conn.LastUsed = time.Now()
|
|
return conn, token, nil
|
|
}
|
|
}
|
|
|
|
// Token not provided or connection not found, create new
|
|
pool, err := getPool(dbName)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
conn, err := pool.Take(context.Background())
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
// Generate new token
|
|
newToken := generateConnToken()
|
|
|
|
trackedConn := &TrackedConn{
|
|
Conn: conn,
|
|
Pool: pool,
|
|
DBName: dbName,
|
|
LastUsed: time.Now(),
|
|
}
|
|
|
|
activeConnMu.Lock()
|
|
activeConns[newToken] = trackedConn
|
|
activeConnMu.Unlock()
|
|
|
|
return trackedConn, newToken, nil
|
|
}
|
|
|
|
// releaseConnection releases a connection back to the pool
|
|
func releaseConnection(token string) bool {
|
|
activeConnMu.Lock()
|
|
defer activeConnMu.Unlock()
|
|
|
|
conn, exists := activeConns[token]
|
|
if !exists {
|
|
return false
|
|
}
|
|
|
|
conn.Pool.Put(conn.Conn)
|
|
delete(activeConns, token)
|
|
return true
|
|
}
|
|
|
|
// sqlQuery executes a SQL query and returns results
|
|
func sqlQuery(state *luajit.State) int {
|
|
// Get required parameters
|
|
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
|
state.PushString("sqlite.query: requires database name and query")
|
|
return -1
|
|
}
|
|
|
|
dbName := state.ToString(1)
|
|
query := state.ToString(2)
|
|
|
|
// Get connection token (optional)
|
|
var connToken string
|
|
if state.GetTop() >= 4 && state.IsString(4) {
|
|
connToken = state.ToString(4)
|
|
}
|
|
|
|
// Get connection
|
|
trackedConn, newToken, err := getConnection(connToken, dbName)
|
|
if err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
|
return -1
|
|
}
|
|
|
|
conn := trackedConn.Conn
|
|
|
|
// Create execution options
|
|
var execOpts sqlitex.ExecOptions
|
|
rows := make([]map[string]any, 0, 16)
|
|
|
|
// Set up parameters if provided
|
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
|
if state.IsTable(3) {
|
|
params, err := state.ToTable(3)
|
|
if err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error()))
|
|
return -1
|
|
}
|
|
|
|
// Check for array-style params
|
|
if arr, ok := params[""]; ok {
|
|
if arrParams, ok := arr.([]any); ok {
|
|
execOpts.Args = arrParams
|
|
} else if floatArr, ok := arr.([]float64); ok {
|
|
args := make([]any, len(floatArr))
|
|
for i, v := range floatArr {
|
|
args[i] = v
|
|
}
|
|
execOpts.Args = args
|
|
}
|
|
} else {
|
|
// Named parameters
|
|
named := make(map[string]any, len(params))
|
|
for k, v := range params {
|
|
if len(k) > 0 && k[0] != ':' {
|
|
named[":"+k] = v
|
|
} else {
|
|
named[k] = v
|
|
}
|
|
}
|
|
execOpts.Named = named
|
|
}
|
|
} else {
|
|
// Positional parameters
|
|
count := state.GetTop() - 2
|
|
if state.IsString(4) {
|
|
count-- // Don't include connection token
|
|
}
|
|
args := make([]any, count)
|
|
for i := range count {
|
|
idx := i + 3
|
|
switch state.GetType(idx) {
|
|
case luajit.TypeNumber:
|
|
args[i] = state.ToNumber(idx)
|
|
case luajit.TypeString:
|
|
args[i] = state.ToString(idx)
|
|
case luajit.TypeBoolean:
|
|
args[i] = state.ToBoolean(idx)
|
|
case luajit.TypeNil:
|
|
args[i] = nil
|
|
default:
|
|
val, err := state.ToValue(idx)
|
|
if err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error()))
|
|
return -1
|
|
}
|
|
args[i] = val
|
|
}
|
|
}
|
|
execOpts.Args = args
|
|
}
|
|
}
|
|
|
|
// Set up result function
|
|
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
|
row := make(map[string]any)
|
|
colCount := stmt.ColumnCount()
|
|
|
|
for i := range colCount {
|
|
colName := stmt.ColumnName(i)
|
|
switch stmt.ColumnType(i) {
|
|
case sqlite.TypeInteger:
|
|
row[colName] = stmt.ColumnInt64(i)
|
|
case sqlite.TypeFloat:
|
|
row[colName] = stmt.ColumnFloat(i)
|
|
case sqlite.TypeText:
|
|
row[colName] = stmt.ColumnText(i)
|
|
case sqlite.TypeBlob:
|
|
blobSize := stmt.ColumnLen(i)
|
|
buf := make([]byte, blobSize)
|
|
row[colName] = stmt.ColumnBytes(i, buf)
|
|
case sqlite.TypeNull:
|
|
row[colName] = nil
|
|
}
|
|
}
|
|
rows = append(rows, row)
|
|
return nil
|
|
}
|
|
|
|
// Execute query
|
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
|
return -1
|
|
}
|
|
|
|
// Create result table
|
|
state.NewTable()
|
|
for i, row := range rows {
|
|
state.PushNumber(float64(i + 1))
|
|
if err := state.PushTable(row); err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
|
return -1
|
|
}
|
|
state.SetTable(-3)
|
|
}
|
|
|
|
// Return connection token
|
|
state.PushString(newToken)
|
|
|
|
return 2
|
|
}
|
|
|
|
// sqlExec executes a SQL statement without returning results
|
|
func sqlExec(state *luajit.State) int {
|
|
// Get required parameters
|
|
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
|
state.PushString("sqlite.exec: requires database name and query")
|
|
return -1
|
|
}
|
|
|
|
dbName := state.ToString(1)
|
|
query := state.ToString(2)
|
|
|
|
// Get connection token (optional)
|
|
var connToken string
|
|
if state.GetTop() >= 4 && state.IsString(4) {
|
|
connToken = state.ToString(4)
|
|
}
|
|
|
|
// Get connection
|
|
trackedConn, newToken, err := getConnection(connToken, dbName)
|
|
if err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
|
return -1
|
|
}
|
|
|
|
conn := trackedConn.Conn
|
|
|
|
// Check if parameters are provided
|
|
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
|
|
|
// Fast path for multi-statement scripts - use ExecScript
|
|
if strings.Contains(query, ";") && !hasParams {
|
|
if err := sqlitex.ExecScript(conn, query); err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
|
return -1
|
|
}
|
|
state.PushNumber(float64(conn.Changes()))
|
|
state.PushString(newToken)
|
|
return 2
|
|
}
|
|
|
|
// Fast path for simple queries with no parameters
|
|
if !hasParams {
|
|
// Use Execute for simple statements without parameters
|
|
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
|
return -1
|
|
}
|
|
state.PushNumber(float64(conn.Changes()))
|
|
state.PushString(newToken)
|
|
return 2
|
|
}
|
|
|
|
// Create execution options for parameterized query
|
|
var execOpts sqlitex.ExecOptions
|
|
|
|
// Set up parameters
|
|
if state.IsTable(3) {
|
|
params, err := state.ToTable(3)
|
|
if err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error()))
|
|
return -1
|
|
}
|
|
|
|
// Check for array-style params
|
|
if arr, ok := params[""]; ok {
|
|
if arrParams, ok := arr.([]any); ok {
|
|
execOpts.Args = arrParams
|
|
} else if floatArr, ok := arr.([]float64); ok {
|
|
args := make([]any, len(floatArr))
|
|
for i, v := range floatArr {
|
|
args[i] = v
|
|
}
|
|
execOpts.Args = args
|
|
}
|
|
} else {
|
|
// Named parameters
|
|
named := make(map[string]any, len(params))
|
|
for k, v := range params {
|
|
if len(k) > 0 && k[0] != ':' {
|
|
named[":"+k] = v
|
|
} else {
|
|
named[k] = v
|
|
}
|
|
}
|
|
execOpts.Named = named
|
|
}
|
|
} else {
|
|
// Positional parameters
|
|
count := state.GetTop() - 2
|
|
if state.IsString(4) {
|
|
count-- // Don't include connection token
|
|
}
|
|
args := make([]any, count)
|
|
for i := range count {
|
|
idx := i + 3
|
|
switch state.GetType(idx) {
|
|
case luajit.TypeNumber:
|
|
args[i] = state.ToNumber(idx)
|
|
case luajit.TypeString:
|
|
args[i] = state.ToString(idx)
|
|
case luajit.TypeBoolean:
|
|
args[i] = state.ToBoolean(idx)
|
|
case luajit.TypeNil:
|
|
args[i] = nil
|
|
default:
|
|
val, err := state.ToValue(idx)
|
|
if err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error()))
|
|
return -1
|
|
}
|
|
args[i] = val
|
|
}
|
|
}
|
|
execOpts.Args = args
|
|
}
|
|
|
|
// Execute with parameters
|
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
|
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
|
return -1
|
|
}
|
|
|
|
// Return affected rows and connection token
|
|
state.PushNumber(float64(conn.Changes()))
|
|
state.PushString(newToken)
|
|
return 2
|
|
}
|
|
|
|
// sqlClose releases a connection back to the pool
|
|
func sqlClose(state *luajit.State) int {
|
|
if state.GetTop() < 1 || !state.IsString(1) {
|
|
state.PushString("sqlite.close: requires connection token")
|
|
return -1
|
|
}
|
|
|
|
token := state.ToString(1)
|
|
if releaseConnection(token) {
|
|
state.PushBoolean(true)
|
|
} else {
|
|
state.PushBoolean(false)
|
|
}
|
|
|
|
return 1
|
|
}
|
|
|
|
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
|
|
func RegisterSQLiteFunctions(state *luajit.State) error {
|
|
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
|
return err
|
|
}
|
|
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
|
|
return err
|
|
}
|
|
if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|