Moonshark/core/runner/sqlite.go

642 lines
16 KiB
Go

package runner
import (
"context"
"errors"
"fmt"
"path/filepath"
"strings"
"sync"
sqlite "zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
"Moonshark/core/utils/logger"
"maps"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
type (
	// SQLiteConnection records a pooled connection that has been checked out
	// under a caller-supplied connection ID.
	SQLiteConnection struct {
		// DbName is the database name exactly as the caller passed it.
		DbName string
		// Conn is the checked-out connection.
		Conn *sqlite.Conn
		// Pool is the pool Conn must be returned to.
		Pool *sqlitex.Pool
	}

	// SQLiteManager owns the per-database connection pools and the table of
	// connections currently checked out by connection ID.
	SQLiteManager struct {
		// mu guards pools and activeConns.
		mu sync.RWMutex
		// pools maps a sanitized database name to its pool.
		pools map[string]*sqlitex.Pool
		// activeConns maps a connection ID to its checked-out connection.
		activeConns map[string]*SQLiteConnection
		// dataDir is the directory database files are created in.
		dataDir string
	}
)

// sqliteManager is the process-wide manager; it stays nil until InitSQLite runs.
var sqliteManager *SQLiteManager
// InitSQLite initializes the SQLite manager
func InitSQLite(dataDir string) {
sqliteManager = &SQLiteManager{
pools: make(map[string]*sqlitex.Pool),
activeConns: make(map[string]*SQLiteConnection),
dataDir: dataDir,
}
logger.Server("SQLite initialized with data directory: %s", dataDir)
}
// CleanupSQLite returns every tracked connection to its pool and closes every
// open pool. It does nothing if InitSQLite has not run.
//
// The manager's maps are replaced with fresh empty maps rather than set to
// nil. Later calls into getPool/getConnection therefore reopen databases on
// demand instead of panicking on a write to a nil map.
func CleanupSQLite() {
	m := sqliteManager
	if m == nil {
		return
	}

	m.mu.Lock()
	for _, conn := range m.activeConns {
		if conn.Pool != nil {
			conn.Pool.Put(conn.Conn)
		}
	}
	pools := m.pools
	m.activeConns = make(map[string]*SQLiteConnection)
	m.pools = make(map[string]*sqlitex.Pool)
	m.mu.Unlock()

	// Close the pools without holding m.mu. Per its docs, sqlitex.Pool.Close
	// may wait for checked-out connections to be returned, and code that
	// returns a connection can need m.mu first.
	for name, pool := range pools {
		if err := pool.Close(); err != nil {
			logger.Error("Failed to close database %s: %v", name, err)
		}
	}

	logger.Debug("SQLite connections closed")
}
// ReleaseActiveConnections returns connections to their pools. It handles each
// connection whose ID appears as the string "id" field of an entry in the Lua
// global table __active_sqlite_connections, then sets that global to nil.
//
// The Lua stack is left exactly as it was found.
func ReleaseActiveConnections(state *luajit.State) {
	if sqliteManager == nil {
		return
	}

	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	state.GetGlobal("__active_sqlite_connections")
	if !state.IsTable(-1) {
		state.Pop(1)
		return
	}

	// Stack: [tbl]. Seed the traversal with a nil key.
	state.PushNil()
	for state.Next(-2) {
		// Stack: [tbl, key, value]
		if state.IsTable(-1) {
			state.GetField(-1, "id")
			if state.IsString(-1) {
				connID := state.ToString(-1)
				if conn, ok := sqliteManager.activeConns[connID]; ok {
					if conn.Pool != nil {
						conn.Pool.Put(conn.Conn)
					}
					delete(sqliteManager.activeConns, connID)
				}
			}
			state.Pop(1) // pop the id field
		}
		state.Pop(1) // pop the value, keeping the key for the next Next call
	}
	// The final Next consumed the key; pop the table so the stack is balanced.
	state.Pop(1)

	// Clear the Lua-side bookkeeping table.
	state.PushNil()
	state.SetGlobal("__active_sqlite_connections")
}
// getPool returns the connection pool for dbName, opening
// <dataDir>/<name>.db on first use.
//
// dbName is reduced to its final path element so callers cannot escape
// dataDir. Names that are empty, start with '.', or are the bare root
// separator are rejected.
func getPool(dbName string) (*sqlitex.Pool, error) {
	if sqliteManager == nil {
		return nil, errors.New("SQLite not initialized")
	}

	dbName = filepath.Base(dbName)
	if dbName == "" || dbName[0] == '.' || dbName == string(filepath.Separator) {
		return nil, errors.New("invalid database name")
	}

	// Fast path: the pool already exists.
	sqliteManager.mu.RLock()
	pool, ok := sqliteManager.pools[dbName]
	sqliteManager.mu.RUnlock()
	if ok {
		return pool, nil
	}

	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	// Another goroutine may have created the pool between the two locks.
	if existing, ok := sqliteManager.pools[dbName]; ok {
		return existing, nil
	}

	dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
	pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	// The map may be nil if a cleanup routine dropped it; never write to a nil map.
	if sqliteManager.pools == nil {
		sqliteManager.pools = make(map[string]*sqlitex.Pool)
	}
	sqliteManager.pools[dbName] = pool
	return pool, nil
}
// getConnection returns the connection tracked under connID. If none exists
// yet, it checks one out of dbName's pool and records it under connID.
//
// A recorded connection stays checked out until ReleaseActiveConnections or
// CleanupSQLite returns it. The second result is the pool the connection
// belongs to.
//
// NOTE(review): activeConns is keyed by connID alone, so once an ID is bound to
// one database it is returned for any dbName. Confirm callers never reuse an
// ID across databases.
func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, error) {
	if sqliteManager == nil {
		return nil, nil, errors.New("SQLite not initialized")
	}

	sqliteManager.mu.RLock()
	tracked, ok := sqliteManager.activeConns[connID]
	sqliteManager.mu.RUnlock()
	if ok {
		return tracked.Conn, tracked.Pool, nil
	}

	pool, err := getPool(dbName)
	if err != nil {
		return nil, nil, err
	}

	dbConn, err := pool.Take(context.Background())
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
	}

	sqliteManager.mu.Lock()
	defer sqliteManager.mu.Unlock()

	// Another goroutine may have registered connID while we were taking a
	// connection. Keep its connection and give ours back so it is not leaked.
	if tracked, ok := sqliteManager.activeConns[connID]; ok {
		pool.Put(dbConn)
		return tracked.Conn, tracked.Pool, nil
	}

	// The map may be nil if a cleanup routine dropped it; never write to a nil map.
	if sqliteManager.activeConns == nil {
		sqliteManager.activeConns = make(map[string]*SQLiteConnection)
	}
	sqliteManager.activeConns[connID] = &SQLiteConnection{
		DbName: dbName,
		Conn:   dbConn,
		Pool:   pool,
	}
	return dbConn, pool, nil
}
// detectParamType reports whether params holds positional (array-like)
// arguments rather than named ones.
//
// Positional means one of these:
//   - a []any or []float64 directly;
//   - a map[string]any whose "" key holds a []any or []float64.
//
// Anything else, including nil, is reported as not positional.
func detectParamType(params any) (isArray bool) {
	switch p := params.(type) {
	case []any, []float64:
		return true
	case map[string]any:
		switch p[""].(type) {
		case []any, []float64:
			return true
		}
	}
	return false
}
// prepareParams splits query parameters into named or positional form.
//
// Positional form (second result) is produced for:
//   - a []any or []float64;
//   - a map[string]any whose "" key holds a []any or []float64.
//
// []float64 values are widened element-wise to []any.
//
// Any other map[string]any is treated as named parameters (first result). Each
// key gets a ':' prefix unless it is empty or already starts with one of
// SQLite's named-parameter prefixes ':', '@' or '$'.
//
// nil or any other type yields (nil, nil).
func prepareParams(params any) (map[string]any, []any) {
	switch p := params.(type) {
	case nil:
		return nil, nil
	case []any:
		return nil, p
	case []float64:
		return nil, floatsToAnySlice(p)
	case map[string]any:
		// Array-like Lua tables arrive with their elements under the "" key.
		switch arr := p[""].(type) {
		case []any:
			return nil, arr
		case []float64:
			return nil, floatsToAnySlice(arr)
		}
		named := make(map[string]any, len(p))
		for key, value := range p {
			named[normalizeParamName(key)] = value
		}
		return named, nil
	}
	return nil, nil
}

// floatsToAnySlice widens a []float64 into a freshly allocated []any.
func floatsToAnySlice(fs []float64) []any {
	out := make([]any, len(fs))
	for i, f := range fs {
		out[i] = f
	}
	return out
}

// normalizeParamName adds SQLite's ':' prefix to key. Empty keys and keys
// already starting with ':', '@' or '$' are returned unchanged.
func normalizeParamName(key string) string {
	if key == "" {
		return key
	}
	switch key[0] {
	case ':', '@', '$':
		return key
	}
	return ":" + key
}
// sqlCall holds the decoded arguments of a __sqlite_query / __sqlite_exec call.
type sqlCall struct {
	dbName string
	query  string
	// connID names a connection tracked across calls. Empty (or a "temp_"
	// prefix) means the call uses a temporary connection that is returned to
	// the pool when the call finishes.
	connID string
	// named and args are the bind parameters; at most one of them is non-nil.
	named map[string]any
	args  []any
}

// parseSQLCall decodes the Lua arguments shared by sqlite.query and
// sqlite.exec. fn ("sqlite.query" or "sqlite.exec") prefixes error messages.
//
// Two call shapes are accepted:
//
//	(db, sql [, params_table [, conn_id]])
//	(db, sql, arg1, arg2, ... [, conn_id])
//
// In the second shape, a trailing string whose Lua type differs from the
// argument before it is taken as the connection ID rather than a parameter.
func parseSQLCall(state *luajit.State, fn string) (*sqlCall, error) {
	if !state.IsString(1) {
		return nil, fmt.Errorf("%s: database name must be a string", fn)
	}
	call := &sqlCall{dbName: state.ToString(1)}

	if !state.IsString(2) {
		return nil, fmt.Errorf("%s: query must be a string", fn)
	}
	call.query = state.ToString(2)

	top := state.GetTop()
	if top >= 3 && !state.IsTable(3) {
		// Positional parameters passed directly as function arguments.
		last := top
		if state.IsString(last) && state.GetType(last-1) != state.GetType(last) {
			call.connID = state.ToString(last)
			last--
		}
		// Always non-nil, even with zero parameters, so the positional path is taken.
		call.args = make([]any, 0, last-2)
		for idx := 3; idx <= last; idx++ {
			v, err := luaArgToGo(state, idx)
			if err != nil {
				return nil, fmt.Errorf("%s: failed to convert parameter %d: %w", fn, idx-2, err)
			}
			call.args = append(call.args, v)
		}
		return call, nil
	}

	// Parameter-table form.
	if top >= 4 && !state.IsNil(4) && state.IsString(4) {
		call.connID = state.ToString(4)
	}
	if top >= 3 && !state.IsNil(3) && state.IsTable(3) {
		tbl, err := state.ToTable(3)
		if err != nil {
			return nil, fmt.Errorf("%s: failed to parse parameters: %w", fn, err)
		}
		call.named, call.args = prepareParams(tbl)
	}
	return call, nil
}

// luaArgToGo converts the Lua value at stack index idx into a Go bind value.
func luaArgToGo(state *luajit.State, idx int) (any, error) {
	switch state.GetType(idx) {
	case luajit.TypeNumber:
		return state.ToNumber(idx), nil
	case luajit.TypeString:
		return state.ToString(idx), nil
	case luajit.TypeBoolean:
		return state.ToBoolean(idx), nil
	case luajit.TypeNil:
		return nil, nil
	default:
		return state.ToValue(idx)
	}
}

// acquireSQLConn returns a connection for (dbName, connID) and a release
// function the caller must invoke when finished.
//
// A tracked connection (non-empty connID without the "temp_" prefix) comes from
// getConnection. Its release is a no-op; it is returned later by
// ReleaseActiveConnections or CleanupSQLite.
//
// A temporary connection is taken straight from the pool without being
// registered, so no per-call ID has to be unique. Its release puts it back.
func acquireSQLConn(dbName, connID string) (*sqlite.Conn, func(), error) {
	if connID != "" && !strings.HasPrefix(connID, "temp_") {
		conn, _, err := getConnection(dbName, connID)
		if err != nil {
			return nil, nil, err
		}
		return conn, func() {}, nil
	}

	pool, err := getPool(dbName)
	if err != nil {
		return nil, nil, err
	}
	conn, err := pool.Take(context.Background())
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
	}
	return conn, func() { pool.Put(conn) }, nil
}

// trimSQLArgs drops positional arguments beyond the number of '?' characters in
// query. This is a plain character count, so a '?' inside a string literal or
// comment is counted too.
func trimSQLArgs(args []any, query string) []any {
	if n := strings.Count(query, "?"); len(args) > n {
		return args[:n]
	}
	return args
}

// readSQLRow stores every column of stmt's current row into dst, keyed by
// column name.
//
// BLOB values are copied out and stored as Go strings, the same type used for
// TEXT columns. NOTE(review): confirm PushTable preserves embedded NUL bytes.
func readSQLRow(stmt *sqlite.Stmt, dst map[string]any) {
	for i := range stmt.ColumnCount() {
		name := stmt.ColumnName(i)
		switch stmt.ColumnType(i) {
		case sqlite.TypeInteger:
			dst[name] = stmt.ColumnInt64(i)
		case sqlite.TypeFloat:
			dst[name] = stmt.ColumnFloat(i)
		case sqlite.TypeText:
			dst[name] = stmt.ColumnText(i)
		case sqlite.TypeBlob:
			buf := make([]byte, stmt.ColumnLen(i))
			// ColumnBytes returns the number of bytes copied, not the data.
			n := stmt.ColumnBytes(i, buf)
			dst[name] = string(buf[:n])
		case sqlite.TypeNull:
			dst[name] = nil
		}
	}
}

// luaSQLQuery executes a SQL query and returns its rows to Lua.
//
// Lua arguments follow parseSQLCall.
//
// On success it pushes a 1-based array table with one column-name → value
// table per row, and returns 1. On failure it pushes an error message prefixed
// with "sqlite.query: " and returns -1 (presumably the wrapper's error signal).
func luaSQLQuery(state *luajit.State) int {
	call, err := parseSQLCall(state, "sqlite.query")
	if err != nil {
		state.PushString(err.Error())
		return -1
	}

	conn, release, err := acquireSQLConn(call.dbName, call.connID)
	if err != nil {
		state.PushString("sqlite.query: " + err.Error())
		return -1
	}
	defer release()

	var (
		rows    []map[string]any
		scratch map[string]any
	)
	opts := &sqlitex.ExecOptions{
		ResultFunc: func(stmt *sqlite.Stmt) error {
			if scratch == nil {
				scratch = make(map[string]any, stmt.ColumnCount())
			}
			clear(scratch)
			readSQLRow(stmt, scratch)
			// scratch is reused for every row, so store an independent copy.
			rows = append(rows, maps.Clone(scratch))
			return nil
		},
	}
	if call.named != nil {
		opts.Named = call.named
	} else if call.args != nil {
		opts.Args = trimSQLArgs(call.args, call.query)
	}

	if err := sqlitex.Execute(conn, call.query, opts); err != nil {
		state.PushString("sqlite.query: " + err.Error())
		return -1
	}

	state.NewTable()
	for i, row := range rows {
		state.PushNumber(float64(i + 1))
		if err := state.PushTable(row); err != nil {
			state.PushString("sqlite.query: " + err.Error())
			return -1
		}
		state.SetTable(-3)
	}
	return 1
}

// luaSQLExec executes a SQL statement without returning rows.
//
// Lua arguments follow parseSQLCall. Execution mode depends on the parameters:
//   - positional parameters run through sqlitex.Execute with Args;
//   - named parameters run through sqlitex.Execute with Named;
//   - with no parameters, the text runs through sqlitex.ExecScript.
//
// On success it pushes conn.Changes() (presumably the row count of the last
// statement) and returns 1. On failure it pushes an error message prefixed
// with "sqlite.exec: " and returns -1.
func luaSQLExec(state *luajit.State) int {
	call, err := parseSQLCall(state, "sqlite.exec")
	if err != nil {
		state.PushString(err.Error())
		return -1
	}

	conn, release, err := acquireSQLConn(call.dbName, call.connID)
	if err != nil {
		state.PushString("sqlite.exec: " + err.Error())
		return -1
	}
	defer release()

	switch {
	case call.args != nil:
		err = sqlitex.Execute(conn, call.query, &sqlitex.ExecOptions{
			Args: trimSQLArgs(call.args, call.query),
		})
	case call.named != nil:
		err = sqlitex.Execute(conn, call.query, &sqlitex.ExecOptions{
			Named: call.named,
		})
	default:
		err = sqlitex.ExecScript(conn, call.query)
	}
	if err != nil {
		state.PushString("sqlite.exec: " + err.Error())
		return -1
	}

	state.PushNumber(float64(conn.Changes()))
	return 1
}
// RegisterSQLiteFunctions exposes the SQLite bridge to Lua as the globals
// __sqlite_query and __sqlite_exec. It stops at, and returns, the first
// registration error.
func RegisterSQLiteFunctions(state *luajit.State) error {
	bindings := []struct {
		name string
		fn   func(*luajit.State) int
	}{
		{"__sqlite_query", luaSQLQuery},
		{"__sqlite_exec", luaSQLExec},
	}
	for _, b := range bindings {
		if err := state.RegisterGoFunction(b.name, b.fn); err != nil {
			return err
		}
	}
	return nil
}