// Moonshark/runner/sqlite.go
//
// Source listing metadata: 442 lines, 11 KiB, Go.

package runner
import (
	"context"
	"errors"
	"fmt"
	"path/filepath"
	"strings"
	"sync"
	"sync/atomic"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
	sqlite "zombiezen.com/go/sqlite"
	"zombiezen.com/go/sqlite/sqlitex"

	"Moonshark/utils/logger"
)
// SQLiteConnection records a connection that has been checked out of a pool,
// together with the pool it must eventually be handed back to via Pool.Put.
type SQLiteConnection struct {
	Conn *sqlite.Conn  // the checked-out connection
	Pool *sqlitex.Pool // owner pool; Conn is returned here on release
}
// SQLiteManager keeps one connection pool per database name and the set of
// connections currently checked out, keyed by connection ID.
type SQLiteManager struct {
	mu          sync.RWMutex                 // guards pools and activeConns
	pools       map[string]*sqlitex.Pool     // database name -> pool
	activeConns map[string]*SQLiteConnection // connection ID -> checked-out connection
	dataDir     string                       // directory that holds the .db files
}

// sqliteManager is the package-wide manager; it stays nil until InitSQLite runs.
var sqliteManager *SQLiteManager
// InitSQLite initializes the SQLite manager
func InitSQLite(dataDir string) {
sqliteManager = &SQLiteManager{
pools: make(map[string]*sqlitex.Pool),
activeConns: make(map[string]*SQLiteConnection),
dataDir: dataDir,
}
logger.Server("SQLite initialized with data directory: %s", dataDir)
}
// CleanupSQLite returns every tracked connection to its pool, detaches all
// pools from the manager and then closes them.
//
// After it runs the manager's maps are nil, which other functions in this
// file treat as "closed".
//
// Pool.Close waits for checked-out connections to be returned. It is therefore
// called only after mu has been released, so that any goroutine that needs mu
// to hand a connection back can still do so.
func CleanupSQLite() {
	if sqliteManager == nil {
		return
	}

	sqliteManager.mu.Lock()
	for id, conn := range sqliteManager.activeConns {
		if conn.Pool != nil {
			conn.Pool.Put(conn.Conn)
		}
		delete(sqliteManager.activeConns, id)
	}
	pools := sqliteManager.pools
	sqliteManager.pools = nil
	sqliteManager.activeConns = nil
	sqliteManager.mu.Unlock()

	for name, pool := range pools {
		if err := pool.Close(); err != nil {
			logger.Error("Failed to close database %s: %v", name, err)
		}
	}

	logger.Debug("SQLite connections closed")
}
// ReleaseActiveConnections returns every connection listed in the Lua global
// __active_sqlite_connections to its pool, then sets that global to nil.
//
// Each value in the table is expected to be a table with a string "id" field
// naming a connection registered with the manager. Other values, and IDs the
// manager does not know, are skipped. The Lua stack is left as it was found.
func ReleaseActiveConnections(state *luajit.State) {
	if sqliteManager == nil {
		return
	}

	state.GetGlobal("__active_sqlite_connections")
	if !state.IsTable(-1) {
		state.Pop(1)
		return
	}

	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	// Inside the loop the stack is [... tbl key value].
	state.PushNil()
	for state.Next(-2) {
		if state.IsTable(-1) {
			state.GetField(-1, "id")
			if state.IsString(-1) {
				connID := state.ToString(-1)
				if conn, exists := sqliteManager.activeConns[connID]; exists {
					if conn.Pool != nil {
						conn.Pool.Put(conn.Conn)
					}
					delete(sqliteManager.activeConns, connID)
				}
			}
			state.Pop(1) // pop the id field
		}
		state.Pop(1) // pop the value, keep the key for the next Next call
	}
	// Next already popped the final key; pop the table that GetGlobal pushed,
	// otherwise every call leaks one stack slot.
	state.Pop(1)

	state.PushNil()
	state.SetGlobal("__active_sqlite_connections")
}
// getConnection returns the connection registered under connID. If none is
// registered, it checks a new one out of the pool for dbName, creating that
// pool on first use, and registers it under connID.
//
// dbName is reduced to its base name and rejected if empty or dot-prefixed.
// The database file is <dataDir>/<dbName>.db.
//
// Pool.Take can block until another caller returns a connection. That caller
// needs mu to do so (see releaseConnection), so mu is released around Take to
// avoid a deadlock when the pool is exhausted.
func getConnection(dbName, connID string) (*sqlite.Conn, error) {
	if sqliteManager == nil {
		return nil, errors.New("SQLite not initialized")
	}

	dbName = filepath.Base(dbName)
	if dbName == "" || dbName[0] == '.' {
		return nil, errors.New("invalid database name")
	}

	sqliteManager.mu.Lock()
	// CleanupSQLite leaves the maps nil; writing to them would panic.
	if sqliteManager.pools == nil || sqliteManager.activeConns == nil {
		sqliteManager.mu.Unlock()
		return nil, errors.New("SQLite not initialized")
	}
	if conn, exists := sqliteManager.activeConns[connID]; exists {
		sqliteManager.mu.Unlock()
		return conn.Conn, nil
	}
	pool, exists := sqliteManager.pools[dbName]
	if !exists {
		dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
		var err error
		pool, err = sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
		if err != nil {
			sqliteManager.mu.Unlock()
			return nil, fmt.Errorf("failed to open database: %w", err)
		}
		sqliteManager.pools[dbName] = pool
	}
	sqliteManager.mu.Unlock()

	dbConn, err := pool.Take(context.Background())
	if err != nil {
		return nil, fmt.Errorf("failed to get connection from pool: %w", err)
	}

	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	if sqliteManager.activeConns == nil {
		// The manager was cleaned up while we waited; hand the connection back.
		pool.Put(dbConn)
		return nil, errors.New("SQLite not initialized")
	}
	if conn, exists := sqliteManager.activeConns[connID]; exists {
		// Another caller registered this ID while mu was released; keep theirs.
		pool.Put(dbConn)
		return conn.Conn, nil
	}

	sqliteManager.activeConns[connID] = &SQLiteConnection{
		Conn: dbConn,
		Pool: pool,
	}
	return dbConn, nil
}
// releaseConnection unregisters connID and puts its connection back into the
// owning pool. It does nothing for unknown IDs or before InitSQLite.
func releaseConnection(connID string) {
	if sqliteManager == nil {
		return
	}

	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	tracked, ok := sqliteManager.activeConns[connID]
	if ok {
		if tracked.Pool != nil {
			tracked.Pool.Put(tracked.Conn)
		}
		delete(sqliteManager.activeConns, connID)
	}
}
// sqliteTempConnSeq supplies unique suffixes for per-call connection IDs.
var sqliteTempConnSeq atomic.Uint64

// newSQLiteTempConnID returns a connection ID that is unique for the life of
// the process.
//
// Deriving the ID from the address of a local variable (fmt's %p) is unsafe:
// once that local is dead, the GC may recycle its address for a concurrent
// call. That call would then pick up the same registered *sqlite.Conn.
func newSQLiteTempConnID() string {
	return fmt.Sprintf("temp_%d", sqliteTempConnSeq.Add(1))
}

// sqliteAnyArgs copies a typed slice into a []any suitable for ExecOptions.Args.
func sqliteAnyArgs[T any](values []T) []any {
	args := make([]any, len(values))
	for i, v := range values {
		args[i] = v
	}
	return args
}

// sqliteBindParams fills opts.Args or opts.Named from the Lua arguments that
// follow the database name and query (stack index 3 onward).
//
// Supported forms:
//   - a table whose array part is exposed by ToTable under the "" key:
//     bound positionally;
//   - any other table: bound by name, with ':' prepended unless the key
//     already starts with ':', '@' or '$';
//   - extra scalar arguments: bound positionally.
//
// If argument 3 is absent or nil, it binds nothing. Returned errors carry no
// "sqlite.xxx:" prefix; callers add it.
func sqliteBindParams(state *luajit.State, opts *sqlitex.ExecOptions) error {
	if state.GetTop() < 3 || state.IsNil(3) {
		return nil
	}

	if !state.IsTable(3) {
		count := state.GetTop() - 2
		args := make([]any, count)
		for i := range count {
			idx := i + 3
			switch state.GetType(idx) {
			case luajit.TypeNumber:
				args[i] = state.ToNumber(idx)
			case luajit.TypeString:
				args[i] = state.ToString(idx)
			case luajit.TypeBoolean:
				args[i] = state.ToBoolean(idx)
			case luajit.TypeNil:
				args[i] = nil
			default:
				val, err := state.ToValue(idx)
				if err != nil {
					return fmt.Errorf("invalid parameter %d: %s", i+1, err.Error())
				}
				args[i] = val
			}
		}
		opts.Args = args
		return nil
	}

	params, err := state.ToTable(3)
	if err != nil {
		return fmt.Errorf("invalid parameters: %s", err.Error())
	}

	if arr, ok := params[""]; ok {
		switch v := arr.(type) {
		case []any:
			opts.Args = v
		case []float64:
			opts.Args = sqliteAnyArgs(v)
		case []string:
			opts.Args = sqliteAnyArgs(v)
		case []bool:
			opts.Args = sqliteAnyArgs(v)
		case []int:
			opts.Args = sqliteAnyArgs(v)
		default:
			// Previously these were silently dropped, which binds NULLs.
			return fmt.Errorf("invalid parameters: unsupported array type %T", arr)
		}
		return nil
	}

	named := make(map[string]any, len(params))
	for k, v := range params {
		// SQLite parameter names start with ':', '@' or '$'; keep any that
		// already carry a prefix and default the rest to ':'.
		if k != "" && !strings.ContainsRune(":@$", rune(k[0])) {
			k = ":" + k
		}
		named[k] = v
	}
	opts.Named = named
	return nil
}

// sqliteScanRow copies the current result row of stmt into a map keyed by
// column name. BLOB values are returned as Go strings (binary-safe, and they
// map directly onto Lua strings).
func sqliteScanRow(stmt *sqlite.Stmt) map[string]any {
	colCount := stmt.ColumnCount()
	row := make(map[string]any, colCount)
	for i := range colCount {
		name := stmt.ColumnName(i)
		switch stmt.ColumnType(i) {
		case sqlite.TypeInteger:
			row[name] = stmt.ColumnInt64(i)
		case sqlite.TypeFloat:
			row[name] = stmt.ColumnFloat(i)
		case sqlite.TypeText:
			row[name] = stmt.ColumnText(i)
		case sqlite.TypeBlob:
			// ColumnBytes copies into buf and returns the byte count; the
			// data is in buf, not in the return value.
			buf := make([]byte, stmt.ColumnLen(i))
			stmt.ColumnBytes(i, buf)
			row[name] = string(buf)
		case sqlite.TypeNull:
			row[name] = nil
		}
	}
	return row
}

// sqlQuery implements the Lua call query(dbName, sql, [params...]).
//
// It runs sql on a temporary connection and pushes a 1-based array of rows,
// where each row is a table keyed by column name. On failure it pushes an
// error message and returns -1.
func sqlQuery(state *luajit.State) int {
	if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
		state.PushString("sqlite.query: requires database name and query")
		return -1
	}

	dbName := state.ToString(1)
	query := state.ToString(2)
	connID := newSQLiteTempConnID()

	conn, err := getConnection(dbName, connID)
	if err != nil {
		state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
		return -1
	}
	defer releaseConnection(connID)

	var execOpts sqlitex.ExecOptions
	if err := sqliteBindParams(state, &execOpts); err != nil {
		state.PushString("sqlite.query: " + err.Error())
		return -1
	}

	rows := make([]map[string]any, 0, 16)
	execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
		rows = append(rows, sqliteScanRow(stmt))
		return nil
	}

	if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
		state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
		return -1
	}

	state.NewTable()
	for i, row := range rows {
		state.PushNumber(float64(i + 1))
		if err := state.PushTable(row); err != nil {
			state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
			return -1
		}
		state.SetTable(-3)
	}
	return 1
}

// sqlExec implements the Lua call exec(dbName, sql, [params...]).
//
// It runs sql on a temporary connection and pushes conn.Changes(). Without
// parameters, SQL containing ';' is run as a script via sqlitex.ExecScript.
// On failure it pushes an error message and returns -1.
func sqlExec(state *luajit.State) int {
	if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
		state.PushString("sqlite.exec: requires database name and query")
		return -1
	}

	dbName := state.ToString(1)
	query := state.ToString(2)
	connID := newSQLiteTempConnID()

	conn, err := getConnection(dbName, connID)
	if err != nil {
		state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
		return -1
	}
	defer releaseConnection(connID)

	hasParams := state.GetTop() >= 3 && !state.IsNil(3)

	if !hasParams {
		if strings.Contains(query, ";") {
			err = sqlitex.ExecScript(conn, query)
		} else {
			err = sqlitex.Execute(conn, query, nil)
		}
		if err != nil {
			state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
			return -1
		}
		state.PushNumber(float64(conn.Changes()))
		return 1
	}

	var execOpts sqlitex.ExecOptions
	if err := sqliteBindParams(state, &execOpts); err != nil {
		state.PushString("sqlite.exec: " + err.Error())
		return -1
	}

	if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
		state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
		return -1
	}

	state.PushNumber(float64(conn.Changes()))
	return 1
}
// RegisterSQLiteFunctions registers sqlQuery as "__sqlite_query" and sqlExec as
// "__sqlite_exec" on state. It stops and returns the first registration error.
func RegisterSQLiteFunctions(state *luajit.State) error {
	err := state.RegisterGoFunction("__sqlite_query", sqlQuery)
	if err == nil {
		err = state.RegisterGoFunction("__sqlite_exec", sqlExec)
	}
	return err
}