394 lines
9.1 KiB
Go
394 lines
9.1 KiB
Go
package runner
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
|
|
sqlite "zombiezen.com/go/sqlite"
|
|
"zombiezen.com/go/sqlite/sqlitex"
|
|
|
|
"Moonshark/core/utils/logger"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
// SQLiteConnection tracks an active connection
|
|
type SQLiteConnection struct {
|
|
DbName string
|
|
Conn *sqlite.Conn
|
|
Pool *sqlitex.Pool
|
|
}
|
|
|
|
// SQLiteManager handles database connections
|
|
type SQLiteManager struct {
|
|
mu sync.RWMutex
|
|
pools map[string]*sqlitex.Pool
|
|
activeConns map[string]*SQLiteConnection
|
|
dataDir string
|
|
}
|
|
|
|
// Global manager
|
|
var sqliteManager *SQLiteManager
|
|
|
|
// InitSQLite initializes the SQLite manager
|
|
func InitSQLite(dataDir string) {
|
|
sqliteManager = &SQLiteManager{
|
|
pools: make(map[string]*sqlitex.Pool),
|
|
activeConns: make(map[string]*SQLiteConnection),
|
|
dataDir: dataDir,
|
|
}
|
|
logger.Server("SQLite initialized with data directory: %s", dataDir)
|
|
}
|
|
|
|
// CleanupSQLite closes all database connections
|
|
func CleanupSQLite() {
|
|
if sqliteManager == nil {
|
|
return
|
|
}
|
|
|
|
sqliteManager.mu.Lock()
|
|
defer sqliteManager.mu.Unlock()
|
|
|
|
// Release all active connections
|
|
for id, conn := range sqliteManager.activeConns {
|
|
if conn.Pool != nil {
|
|
conn.Pool.Put(conn.Conn)
|
|
}
|
|
delete(sqliteManager.activeConns, id)
|
|
}
|
|
|
|
// Close all pools
|
|
for name, pool := range sqliteManager.pools {
|
|
if err := pool.Close(); err != nil {
|
|
logger.Error("Failed to close database %s: %v", name, err)
|
|
}
|
|
}
|
|
|
|
sqliteManager.pools = nil
|
|
sqliteManager.activeConns = nil
|
|
logger.Debug("SQLite connections closed")
|
|
}
|
|
|
|
// ReleaseActiveConnections returns all active connections to their pools
|
|
func ReleaseActiveConnections(state *luajit.State) {
|
|
if sqliteManager == nil {
|
|
return
|
|
}
|
|
|
|
sqliteManager.mu.Lock()
|
|
defer sqliteManager.mu.Unlock()
|
|
|
|
// Get active connections table from Lua
|
|
state.GetGlobal("__active_sqlite_connections")
|
|
if !state.IsTable(-1) {
|
|
state.Pop(1)
|
|
return
|
|
}
|
|
|
|
// Iterate through active connections
|
|
state.PushNil() // Start iteration
|
|
for state.Next(-2) {
|
|
// Stack now has key at -2 and value at -1
|
|
if state.IsTable(-1) {
|
|
state.GetField(-1, "id")
|
|
if state.IsString(-1) {
|
|
connID := state.ToString(-1)
|
|
|
|
// Release connection from Go side
|
|
if conn, exists := sqliteManager.activeConns[connID]; exists {
|
|
if conn.Pool != nil {
|
|
conn.Pool.Put(conn.Conn)
|
|
}
|
|
delete(sqliteManager.activeConns, connID)
|
|
}
|
|
}
|
|
state.Pop(1) // Pop connection id
|
|
}
|
|
state.Pop(1) // Pop value, leave key for next iteration
|
|
}
|
|
|
|
// Clear the active connections table
|
|
state.PushNil()
|
|
state.SetGlobal("__active_sqlite_connections")
|
|
}
|
|
|
|
// getPool returns a connection pool for the specified database
|
|
func getPool(dbName string) (*sqlitex.Pool, error) {
|
|
if sqliteManager == nil {
|
|
return nil, errors.New("SQLite not initialized")
|
|
}
|
|
|
|
// Validate database name
|
|
dbName = filepath.Base(dbName)
|
|
if dbName == "" || dbName[0] == '.' {
|
|
return nil, errors.New("invalid database name")
|
|
}
|
|
|
|
// Check for existing pool
|
|
sqliteManager.mu.RLock()
|
|
pool, exists := sqliteManager.pools[dbName]
|
|
sqliteManager.mu.RUnlock()
|
|
|
|
if exists {
|
|
return pool, nil
|
|
}
|
|
|
|
// Create new pool
|
|
sqliteManager.mu.Lock()
|
|
defer sqliteManager.mu.Unlock()
|
|
|
|
// Double check if another goroutine created it
|
|
if pool, exists = sqliteManager.pools[dbName]; exists {
|
|
return pool, nil
|
|
}
|
|
|
|
// Create database file path
|
|
dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
|
|
|
|
// Create the pool
|
|
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
sqliteManager.pools[dbName] = pool
|
|
return pool, nil
|
|
}
|
|
|
|
// getConnection returns a connection from the pool
|
|
func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, error) {
|
|
// Check for existing connection first
|
|
sqliteManager.mu.RLock()
|
|
conn, exists := sqliteManager.activeConns[connID]
|
|
sqliteManager.mu.RUnlock()
|
|
|
|
if exists {
|
|
return conn.Conn, conn.Pool, nil
|
|
}
|
|
|
|
// Get the pool
|
|
pool, err := getPool(dbName)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Get a connection
|
|
dbConn := pool.Get(nil)
|
|
if dbConn == nil {
|
|
return nil, nil, errors.New("failed to get connection from pool")
|
|
}
|
|
|
|
// Store connection
|
|
sqliteManager.mu.Lock()
|
|
sqliteManager.activeConns[connID] = &SQLiteConnection{
|
|
DbName: dbName,
|
|
Conn: dbConn,
|
|
Pool: pool,
|
|
}
|
|
sqliteManager.mu.Unlock()
|
|
|
|
return dbConn, pool, nil
|
|
}
|
|
|
|
// luaSQLQuery executes a SQL query and returns results to Lua
|
|
func luaSQLQuery(state *luajit.State) int {
|
|
// Get database name
|
|
if !state.IsString(1) {
|
|
state.PushString("sqlite.query: database name must be a string")
|
|
return -1
|
|
}
|
|
dbName := state.ToString(1)
|
|
|
|
// Get query
|
|
if !state.IsString(2) {
|
|
state.PushString("sqlite.query: query must be a string")
|
|
return -1
|
|
}
|
|
query := state.ToString(2)
|
|
|
|
// Get connection ID (optional for compatibility)
|
|
var connID string
|
|
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
|
connID = state.ToString(4)
|
|
} else {
|
|
// Generate a temporary connection ID
|
|
connID = fmt.Sprintf("temp_%p", &query)
|
|
}
|
|
|
|
// Get parameters (optional)
|
|
var params map[string]any
|
|
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
|
var err error
|
|
params, err = state.ToTable(3)
|
|
if err != nil {
|
|
state.PushString("sqlite.query: failed to parse parameters: " + err.Error())
|
|
return -1
|
|
}
|
|
}
|
|
|
|
// Get connection
|
|
conn, pool, err := getConnection(dbName, connID)
|
|
if err != nil {
|
|
state.PushString("sqlite.query: " + err.Error())
|
|
return -1
|
|
}
|
|
|
|
// For temporary connections, defer release
|
|
if !strings.HasPrefix(connID, "temp_") {
|
|
defer pool.Put(conn)
|
|
|
|
// Remove from active connections
|
|
sqliteManager.mu.Lock()
|
|
delete(sqliteManager.activeConns, connID)
|
|
sqliteManager.mu.Unlock()
|
|
}
|
|
|
|
// Execute query and collect results
|
|
var rows []map[string]any
|
|
|
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
|
Named: params, // Using Named for named parameters
|
|
ResultFunc: func(stmt *sqlite.Stmt) error {
|
|
row := make(map[string]any)
|
|
columnCount := stmt.ColumnCount()
|
|
|
|
for i := 0; i < columnCount; i++ {
|
|
columnName := stmt.ColumnName(i)
|
|
columnType := stmt.ColumnType(i)
|
|
|
|
switch columnType {
|
|
case sqlite.TypeInteger:
|
|
row[columnName] = stmt.ColumnInt64(i)
|
|
case sqlite.TypeFloat:
|
|
row[columnName] = stmt.ColumnFloat(i)
|
|
case sqlite.TypeText:
|
|
row[columnName] = stmt.ColumnText(i)
|
|
case sqlite.TypeBlob:
|
|
blobSize := stmt.ColumnLen(i)
|
|
buf := make([]byte, blobSize)
|
|
blob := stmt.ColumnBytes(i, buf)
|
|
row[columnName] = blob
|
|
case sqlite.TypeNull:
|
|
row[columnName] = nil
|
|
}
|
|
}
|
|
|
|
// Add row copy to results
|
|
rowCopy := make(map[string]any, len(row))
|
|
for k, v := range row {
|
|
rowCopy[k] = v
|
|
}
|
|
rows = append(rows, rowCopy)
|
|
return nil
|
|
},
|
|
})
|
|
|
|
if err != nil {
|
|
state.PushString("sqlite.query: " + err.Error())
|
|
return -1
|
|
}
|
|
|
|
// Create result table
|
|
state.NewTable()
|
|
|
|
// Add results to the table
|
|
for i, row := range rows {
|
|
state.PushNumber(float64(i + 1))
|
|
if err := state.PushTable(row); err != nil {
|
|
state.PushString("sqlite.query: " + err.Error())
|
|
return -1
|
|
}
|
|
state.SetTable(-3)
|
|
}
|
|
|
|
return 1
|
|
}
|
|
|
|
// luaSQLExec executes a SQL statement without returning results
|
|
func luaSQLExec(state *luajit.State) int {
|
|
// Get database name and query
|
|
if !state.IsString(1) {
|
|
state.PushString("sqlite.exec: database name must be a string")
|
|
return -1
|
|
}
|
|
dbName := state.ToString(1)
|
|
|
|
if !state.IsString(2) {
|
|
state.PushString("sqlite.exec: query must be a string")
|
|
return -1
|
|
}
|
|
query := state.ToString(2)
|
|
|
|
// Get connection ID (optional for compatibility)
|
|
var connID string
|
|
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
|
connID = state.ToString(4)
|
|
} else {
|
|
// Generate a temporary connection ID
|
|
connID = fmt.Sprintf("temp_%p", &query)
|
|
}
|
|
|
|
// Get parameters (optional)
|
|
var params map[string]any
|
|
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
|
var err error
|
|
params, err = state.ToTable(3)
|
|
if err != nil {
|
|
state.PushString("sqlite.exec: failed to parse parameters: " + err.Error())
|
|
return -1
|
|
}
|
|
}
|
|
|
|
// Get connection
|
|
conn, pool, err := getConnection(dbName, connID)
|
|
if err != nil {
|
|
state.PushString("sqlite.exec: " + err.Error())
|
|
return -1
|
|
}
|
|
|
|
// For temporary connections, defer release
|
|
if !strings.HasPrefix(connID, "temp_") {
|
|
defer pool.Put(conn)
|
|
|
|
// Remove from active connections
|
|
sqliteManager.mu.Lock()
|
|
delete(sqliteManager.activeConns, connID)
|
|
sqliteManager.mu.Unlock()
|
|
}
|
|
|
|
// Execute statement
|
|
if params != nil {
|
|
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
|
Named: params, // Using Named for named parameters
|
|
})
|
|
} else {
|
|
err = sqlitex.ExecScript(conn, query)
|
|
}
|
|
|
|
if err != nil {
|
|
state.PushString("sqlite.exec: " + err.Error())
|
|
return -1
|
|
}
|
|
|
|
// Return number of affected rows
|
|
state.PushNumber(float64(conn.Changes()))
|
|
return 1
|
|
}
|
|
|
|
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
|
|
func RegisterSQLiteFunctions(state *luajit.State) error {
|
|
if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := state.RegisterGoFunction("__sqlite_exec", luaSQLExec); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|