Moonshark/runner/sqlite.go
2025-05-31 09:11:21 -05:00

428 lines
10 KiB
Go

package runner
import (
"context"
"fmt"
"path/filepath"
"strings"
"sync"
"time"
sqlite "zombiezen.com/go/sqlite"
"zombiezen.com/go/sqlite/sqlitex"
"Moonshark/utils/color"
"Moonshark/utils/logger"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
var (
	// dataDir is the directory that holds the database files; InitSQLite sets it.
	dataDir string

	// poolSize is the connection count handed to each newly created pool.
	// It defaults to 8; SetSQLitePoolSize overrides it so it can match the
	// runner pool size.
	poolSize = 8

	// connTimeout bounds how long a call waits to take a pooled connection.
	connTimeout = 5 * time.Second

	// poolsMu guards dbPools.
	poolsMu sync.RWMutex

	// dbPools caches one connection pool per database name.
	dbPools = make(map[string]*sqlitex.Pool)
)
// InitSQLite records dir as the directory in which database files are
// opened and logs that the SQLite subsystem is ready.
func InitSQLite(dir string) {
	dataDir = dir
	highlighted := color.Yellow(dir)
	logger.Info("SQLite is g2g! %s", highlighted)
}
// SetSQLitePoolSize sets the number of connections used for pools created
// from now on, typically to match the runner pool size. Non-positive sizes
// are ignored. Pools that already exist keep the size they were created with.
func SetSQLitePoolSize(size int) {
	if size <= 0 {
		return
	}
	poolSize = size
}
// CleanupSQLite closes every cached connection pool and empties the cache.
// Close failures are logged rather than returned, and the remaining pools
// are still closed. The pool lock is held for the whole operation.
func CleanupSQLite() {
	poolsMu.Lock()
	defer poolsMu.Unlock()

	for dbName := range dbPools {
		closeErr := dbPools[dbName].Close()
		if closeErr == nil {
			continue
		}
		logger.Error("Failed to close database %s: %v", dbName, closeErr)
	}

	dbPools = make(map[string]*sqlitex.Pool)
	logger.Debug("SQLite connections closed")
}
// getPool returns a connection pool for the database
func getPool(dbName string) (*sqlitex.Pool, error) {
// Validate database name
dbName = filepath.Base(dbName)
if dbName == "" || dbName[0] == '.' {
return nil, fmt.Errorf("invalid database name")
}
// Check for existing pool
poolsMu.RLock()
pool, exists := dbPools[dbName]
if exists {
poolsMu.RUnlock()
return pool, nil
}
poolsMu.RUnlock()
// Create new pool under write lock
poolsMu.Lock()
defer poolsMu.Unlock()
// Double-check if a pool was created while waiting for lock
if pool, exists = dbPools[dbName]; exists {
return pool, nil
}
// Create new pool with proper size
dbPath := filepath.Join(dataDir, dbName+".db")
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
PoolSize: poolSize,
PrepareConn: func(conn *sqlite.Conn) error {
// Execute PRAGMA statements individually
pragmas := []string{
"PRAGMA journal_mode = WAL",
"PRAGMA synchronous = NORMAL",
"PRAGMA cache_size = 1000",
"PRAGMA foreign_keys = ON",
"PRAGMA temp_store = MEMORY",
}
for _, pragma := range pragmas {
if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
return err
}
}
return nil
},
})
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
dbPools[dbName] = pool
logger.Debug("Created SQLite pool for %s (size: %d)", dbName, poolSize)
return pool, nil
}
// sqlQuery runs a query and returns every result row to Lua.
//
// Lua arguments: (dbName string, query string [, params]). A non-nil third
// argument is passed to setupParams to bind parameters.
//
// On success it pushes one table and returns 1. The table is a 1-based array
// of row tables keyed by column name. INTEGER values become int64, REAL
// float64, TEXT string, BLOB []byte and NULL nil before PushTable converts
// the row.
//
// On failure it pushes an error string prefixed with "sqlite.query:" and
// returns -1 (NOTE(review): presumably LuaJIT-to-Go's error convention;
// confirm).
func sqlQuery(state *luajit.State) int {
	if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
		state.PushString("sqlite.query: requires database name and query")
		return -1
	}
	dbName := state.ToString(1)
	query := state.ToString(2)

	pool, err := getPool(dbName)
	if err != nil {
		state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
		return -1
	}

	// Bound the wait for a free pooled connection.
	ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
	defer cancel()
	conn, err := pool.Take(ctx)
	if err != nil {
		state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error()))
		return -1
	}
	defer pool.Put(conn)

	var execOpts sqlitex.ExecOptions
	rows := make([]map[string]any, 0, 16)

	if state.GetTop() >= 3 && !state.IsNil(3) {
		if err := setupParams(state, 3, &execOpts); err != nil {
			state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
			return -1
		}
	}

	execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
		colCount := stmt.ColumnCount()
		row := make(map[string]any, colCount)
		for i := range colCount {
			// Columns sharing a name (e.g. from a join) overwrite each other;
			// the last one wins.
			colName := stmt.ColumnName(i)
			switch stmt.ColumnType(i) {
			case sqlite.TypeInteger:
				row[colName] = stmt.ColumnInt64(i)
			case sqlite.TypeFloat:
				row[colName] = stmt.ColumnFloat(i)
			case sqlite.TypeText:
				row[colName] = stmt.ColumnText(i)
			case sqlite.TypeBlob:
				// ColumnBytes copies into buf and returns the number of bytes
				// copied, not the bytes, so store the filled buffer. An empty
				// blob yields an empty, non-nil slice.
				buf := make([]byte, stmt.ColumnLen(i))
				row[colName] = buf[:stmt.ColumnBytes(i, buf)]
			case sqlite.TypeNull:
				row[colName] = nil
			}
		}
		rows = append(rows, row)
		return nil
	}

	if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
		state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
		return -1
	}

	// Build the 1-based result array: result[i+1] = rows[i].
	state.NewTable()
	for i, row := range rows {
		state.PushNumber(float64(i + 1))
		if err := state.PushTable(row); err != nil {
			state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
			return -1
		}
		state.SetTable(-3)
	}
	return 1
}
// sqlExec runs a statement (or script) that produces no result rows and
// returns conn.Changes() to Lua.
//
// Lua arguments: (dbName string, query string [, params]).
//   - With a non-nil third argument, the query is run once with parameters
//     bound by setupParams.
//   - Without parameters, a query containing ';' is run as a script via
//     sqlitex.ExecScript.
//   - Any other query is run with sqlitex.Execute.
//
// On success it pushes the change count and returns 1. On failure it pushes
// an error string prefixed with "sqlite.exec: " and returns -1.
func sqlExec(state *luajit.State) int {
	// fail pushes a prefixed error message and reports failure.
	fail := func(msg string) int {
		state.PushString("sqlite.exec: " + msg)
		return -1
	}

	if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
		return fail("requires database name and query")
	}
	dbName, query := state.ToString(1), state.ToString(2)

	pool, err := getPool(dbName)
	if err != nil {
		return fail(err.Error())
	}

	ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
	defer cancel()
	conn, err := pool.Take(ctx)
	if err != nil {
		return fail("connection timeout: " + err.Error())
	}
	defer pool.Put(conn)

	var runErr error
	switch {
	case state.GetTop() >= 3 && !state.IsNil(3):
		var opts sqlitex.ExecOptions
		if err := setupParams(state, 3, &opts); err != nil {
			return fail(err.Error())
		}
		runErr = sqlitex.Execute(conn, query, &opts)
	case strings.Contains(query, ";"):
		runErr = sqlitex.ExecScript(conn, query)
	default:
		runErr = sqlitex.Execute(conn, query, nil)
	}
	if runErr != nil {
		return fail(runErr.Error())
	}

	state.PushNumber(float64(conn.Changes()))
	return 1
}
// setupParams fills execOpts with bind parameters read from the Lua stack,
// starting at paramIndex.
//
// If the value at paramIndex is a table, it is converted with ToTable:
//   - When the resulting map has an "" entry, that entry is treated as an
//     array of positional arguments. NOTE(review): this relies on how
//     LuaJIT-to-Go's ToTable reports array-style tables; confirm.
//   - Otherwise each key becomes a named parameter. Keys that already start
//     with a SQLite named-parameter prefix (':', '@' or '$') are kept as-is;
//     any other key gets ':' prepended.
//
// If the value at paramIndex is not a table, every stack value from
// paramIndex to the top is bound positionally, in order.
//
// It returns an error in three cases:
//   - ToTable fails;
//   - the array entry has an unsupported type;
//   - a stack value cannot be converted with ToValue.
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
	if !state.IsTable(paramIndex) {
		return setupSQLStackParams(state, paramIndex, execOpts)
	}

	params, err := state.ToTable(paramIndex)
	if err != nil {
		return fmt.Errorf("invalid parameters: %w", err)
	}

	if arr, ok := params[""]; ok {
		args, err := sqlArrayArgs(arr)
		if err != nil {
			return err
		}
		execOpts.Args = args
		return nil
	}

	named := make(map[string]any, len(params))
	for k, v := range params {
		named[sqlNamedParamKey(k)] = v
	}
	execOpts.Named = named
	return nil
}

// setupSQLStackParams binds each stack value from firstIndex to the top as a positional argument.
func setupSQLStackParams(state *luajit.State, firstIndex int, execOpts *sqlitex.ExecOptions) error {
	count := max(state.GetTop()-firstIndex+1, 0)
	args := make([]any, count)
	for i := range count {
		val, err := state.ToValue(firstIndex + i)
		if err != nil {
			return fmt.Errorf("invalid parameter %d: %w", i+1, err)
		}
		args[i] = val
	}
	execOpts.Args = args
	return nil
}

// sqlArrayArgs converts the array entry produced by ToTable into positional arguments.
func sqlArrayArgs(arr any) ([]any, error) {
	switch a := arr.(type) {
	case []any:
		return a, nil
	case []float64:
		return sqlAnySlice(a), nil
	case []string:
		return sqlAnySlice(a), nil
	case []bool:
		return sqlAnySlice(a), nil
	case []int:
		return sqlAnySlice(a), nil
	case []int64:
		return sqlAnySlice(a), nil
	default:
		// Skipping this silently would run the query without the caller's
		// values bound, so report it instead.
		return nil, fmt.Errorf("invalid parameters: unsupported array type %T", arr)
	}
}

// sqlAnySlice copies s element by element into a new []any.
func sqlAnySlice[T any](s []T) []any {
	out := make([]any, len(s))
	for i, v := range s {
		out[i] = v
	}
	return out
}

// sqlNamedParamKey prefixes k with ':' unless it already carries a SQLite named-parameter sigil.
func sqlNamedParamKey(k string) string {
	if k == "" {
		return k
	}
	switch k[0] {
	case ':', '@', '$':
		return k
	}
	return ":" + k
}
// sqlGetOne runs a query and returns only its first row to Lua.
//
// Lua arguments: (dbName string, query string [, params]). A non-nil third
// argument is passed to setupParams to bind parameters.
//
// On success it pushes either a table keyed by column name (INTEGER as
// int64, REAL float64, TEXT string, BLOB []byte, NULL nil) or nil when the
// query yields no rows, and returns 1.
//
// On failure it pushes an error string prefixed with "sqlite.get_one:" and
// returns -1.
func sqlGetOne(state *luajit.State) int {
	if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
		state.PushString("sqlite.get_one: requires database name and query")
		return -1
	}
	dbName := state.ToString(1)
	query := state.ToString(2)

	pool, err := getPool(dbName)
	if err != nil {
		state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
		return -1
	}

	// Bound the wait for a free pooled connection.
	ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
	defer cancel()
	conn, err := pool.Take(ctx)
	if err != nil {
		state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error()))
		return -1
	}
	defer pool.Put(conn)

	var execOpts sqlitex.ExecOptions
	var result map[string]any

	if state.GetTop() >= 3 && !state.IsNil(3) {
		if err := setupParams(state, 3, &execOpts); err != nil {
			state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
			return -1
		}
	}

	execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
		// Only the first row is captured. The statement is still stepped to
		// completion, so later rows are read and ignored.
		if result != nil {
			return nil
		}
		colCount := stmt.ColumnCount()
		result = make(map[string]any, colCount)
		for i := range colCount {
			colName := stmt.ColumnName(i)
			switch stmt.ColumnType(i) {
			case sqlite.TypeInteger:
				result[colName] = stmt.ColumnInt64(i)
			case sqlite.TypeFloat:
				result[colName] = stmt.ColumnFloat(i)
			case sqlite.TypeText:
				result[colName] = stmt.ColumnText(i)
			case sqlite.TypeBlob:
				// ColumnBytes copies into buf and returns the number of bytes
				// copied, not the bytes, so store the filled buffer. An empty
				// blob yields an empty, non-nil slice.
				buf := make([]byte, stmt.ColumnLen(i))
				result[colName] = buf[:stmt.ColumnBytes(i, buf)]
			case sqlite.TypeNull:
				result[colName] = nil
			}
		}
		return nil
	}

	if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
		state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
		return -1
	}

	if result == nil {
		state.PushNil()
		return 1
	}
	if err := state.PushTable(result); err != nil {
		state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
		return -1
	}
	return 1
}
// RegisterSQLiteFunctions registers the SQLite bridge functions with state
// under the names __sqlite_query, __sqlite_exec and __sqlite_get_one, in that
// order. It stops at the first registration failure and returns that error
// unchanged.
func RegisterSQLiteFunctions(state *luajit.State) error {
	bindings := []struct {
		name string
		fn   func(*luajit.State) int
	}{
		{"__sqlite_query", sqlQuery},
		{"__sqlite_exec", sqlExec},
		{"__sqlite_get_one", sqlGetOne},
	}
	for _, b := range bindings {
		if err := state.RegisterGoFunction(b.name, b.fn); err != nil {
			return err
		}
	}
	return nil
}