package runner import ( "context" "fmt" "path/filepath" "strings" "sync" sqlite "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" "Moonshark/utils/logger" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) // DbPools maintains database connection pools var ( dbPools = make(map[string]*sqlitex.Pool) poolsMu sync.RWMutex dataDir string ) // InitSQLite initializes the SQLite subsystem func InitSQLite(dir string) { dataDir = dir logger.Server("SQLite initialized with data directory: %s", dir) } // CleanupSQLite closes all database connections func CleanupSQLite() { poolsMu.Lock() defer poolsMu.Unlock() for name, pool := range dbPools { if err := pool.Close(); err != nil { logger.Error("Failed to close database %s: %v", name, err) } } dbPools = make(map[string]*sqlitex.Pool) logger.Debug("SQLite connections closed") } // getPool returns a connection pool for the database func getPool(dbName string) (*sqlitex.Pool, error) { // Validate database name dbName = filepath.Base(dbName) if dbName == "" || dbName[0] == '.' { return nil, fmt.Errorf("invalid database name") } // Check for existing pool poolsMu.RLock() pool, exists := dbPools[dbName] if exists { poolsMu.RUnlock() return pool, nil } poolsMu.RUnlock() // Create new pool under write lock poolsMu.Lock() defer poolsMu.Unlock() // Double-check if a pool was created while waiting for lock if pool, exists = dbPools[dbName]; exists { return pool, nil } // Create new pool dbPath := filepath.Join(dataDir, dbName+".db") pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } dbPools[dbName] = pool return pool, nil } // sqlQuery executes a SQL query and returns results func sqlQuery(state *luajit.State) int { // Get required parameters if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { state.PushString("sqlite.query: requires database name and query") return -1 } dbName := state.ToString(1) query := state.ToString(2) // Get connection pool pool, err := getPool(dbName) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } // Get a connection from the pool conn, err := pool.Take(context.Background()) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: failed to get connection: %s", err.Error())) return -1 } defer pool.Put(conn) // Create execution options var execOpts sqlitex.ExecOptions rows := make([]map[string]any, 0, 16) // Set up parameters if provided if state.GetTop() >= 3 && !state.IsNil(3) { if state.IsTable(3) { params, err := state.ToTable(3) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error())) return -1 } // Check for array-style params if arr, ok := params[""]; ok { if arrParams, ok := arr.([]any); ok { execOpts.Args = arrParams } else if floatArr, ok := arr.([]float64); ok { args := make([]any, len(floatArr)) for i, v := range floatArr { args[i] = v } execOpts.Args = args } } else { // Named parameters named := make(map[string]any, len(params)) for k, v := range params { if len(k) > 0 && k[0] != ':' { named[":"+k] = v } else { named[k] = v } } execOpts.Named = named } } else { // Positional parameters count := state.GetTop() - 2 args := make([]any, count) for i := range count { idx := i + 3 switch state.GetType(idx) { case luajit.TypeNumber: args[i] = state.ToNumber(idx) case luajit.TypeString: args[i] = state.ToString(idx) case luajit.TypeBoolean: args[i] = state.ToBoolean(idx) case luajit.TypeNil: args[i] = nil default: val, err := state.ToValue(idx) if err != nil { state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error())) return -1 } args[i] = val } } execOpts.Args = args } } // Set up result function execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { row := make(map[string]any) colCount := stmt.ColumnCount() for i := range colCount { colName := stmt.ColumnName(i) switch stmt.ColumnType(i) { case sqlite.TypeInteger: row[colName] = stmt.ColumnInt64(i) case sqlite.TypeFloat: row[colName] = stmt.ColumnFloat(i) case sqlite.TypeText: row[colName] = stmt.ColumnText(i) case sqlite.TypeBlob: blobSize := stmt.ColumnLen(i) buf := make([]byte, blobSize) row[colName] = stmt.ColumnBytes(i, buf) case sqlite.TypeNull: row[colName] = nil } } rows = append(rows, row) return nil } // Execute query if err := sqlitex.Execute(conn, query, &execOpts); err != nil { state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } // Create result table state.NewTable() for i, row := range rows { state.PushNumber(float64(i + 1)) if err := state.PushTable(row); err != nil { state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) return -1 } state.SetTable(-3) } return 1 } // sqlExec executes a SQL statement without returning results func sqlExec(state *luajit.State) int { // Get required parameters if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) { state.PushString("sqlite.exec: requires database name and query") return -1 } dbName := state.ToString(1) query := state.ToString(2) // Get connection pool pool, err := getPool(dbName) if err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } // Get a connection from the pool conn, err := pool.Take(context.Background()) if err != nil { state.PushString(fmt.Sprintf("sqlite.exec: failed to get connection: %s", err.Error())) return -1 } defer pool.Put(conn) // Check if parameters are provided hasParams := state.GetTop() >= 3 && !state.IsNil(3) // Fast path for multi-statement scripts - use ExecScript if strings.Contains(query, ";") && !hasParams { if err := sqlitex.ExecScript(conn, query); err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } state.PushNumber(float64(conn.Changes())) return 1 } // Fast path for simple queries with no parameters if !hasParams { // Use Execute for simple statements without parameters if err := sqlitex.Execute(conn, query, nil); err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } state.PushNumber(float64(conn.Changes())) return 1 } // Create execution options for parameterized query var execOpts sqlitex.ExecOptions // Set up parameters if state.IsTable(3) { params, err := state.ToTable(3) if err != nil { state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error())) return -1 } // Check for array-style params if arr, ok := params[""]; ok { if arrParams, ok := arr.([]any); ok { execOpts.Args = arrParams } else if floatArr, ok := arr.([]float64); ok { args := make([]any, len(floatArr)) for i, v := range floatArr { args[i] = v } execOpts.Args = args } } else { // Named parameters named := make(map[string]any, len(params)) for k, v := range params { if len(k) > 0 && k[0] != ':' { named[":"+k] = v } else { named[k] = v } } execOpts.Named = named } } else { // Positional parameters count := state.GetTop() - 2 args := make([]any, count) for i := range count { idx := i + 3 switch state.GetType(idx) { case luajit.TypeNumber: args[i] = state.ToNumber(idx) case luajit.TypeString: args[i] = state.ToString(idx) case luajit.TypeBoolean: args[i] = state.ToBoolean(idx) case luajit.TypeNil: args[i] = nil default: val, err := state.ToValue(idx) if err != nil { state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error())) return -1 } args[i] = val } } execOpts.Args = args } // Execute with parameters if err := sqlitex.Execute(conn, query, &execOpts); err != nil { state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) return -1 } // Return affected rows state.PushNumber(float64(conn.Changes())) return 1 } // RegisterSQLiteFunctions registers SQLite functions with the Lua state func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { return err } return state.RegisterGoFunction("__sqlite_exec", sqlExec) }