// Package sql exposes pluggable SQL database drivers to Lua scripts.
// Drivers are registered with RegisterDriver (this package's init registers
// the SQLite driver); Lua code opens, uses and closes connections through the
// functions returned by GetFunctionList.
package sql

import (
	"context"
	"fmt"
	"sync"
	"time"

	luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)

// Timeouts applied to the context of each database call made from Lua.
const (
	luaPingTimeout  = 5 * time.Second
	luaQueryTimeout = 30 * time.Second
)

// Driver is implemented by each SQL database backend. Drivers are stored in
// the registry by name and used by Connect to open connections.
type Driver interface {
	Open(dsn string) (Connection, error)
	Name() string
}

// Connection is an open database connection. Connections opened through
// Connect are tracked in the registry under a string ID until
// CloseConnection removes them.
type Connection interface {
	Close() error
	Ping(ctx context.Context) error
	Begin(ctx context.Context) (Transaction, error)
	Query(ctx context.Context, query string, args ...any) (Rows, error)
	QueryRow(ctx context.Context, query string, args ...any) Row
	Exec(ctx context.Context, query string, args ...any) (Result, error)
	Prepare(ctx context.Context, query string) (Statement, error)
}

// Transaction is a database transaction started with Connection.Begin.
type Transaction interface {
	Commit() error
	Rollback() error
	Query(ctx context.Context, query string, args ...any) (Rows, error)
	QueryRow(ctx context.Context, query string, args ...any) Row
	Exec(ctx context.Context, query string, args ...any) (Result, error)
	Prepare(ctx context.Context, query string) (Statement, error)
}

// Rows is an iterator over the rows returned by a query.
type Rows interface {
	Next() bool
	Scan(dest ...any) error
	Columns() ([]string, error)
	Close() error
	Err() error
}

// Row is a single query result row.
type Row interface {
	Scan(dest ...any) error
}

// Result describes the outcome of an executed statement.
type Result interface {
	LastInsertId() (int64, error)
	RowsAffected() (int64, error)
}

// Statement is a prepared statement.
type Statement interface {
	Close() error
	Query(ctx context.Context, args ...any) (Rows, error)
	QueryRow(ctx context.Context, args ...any) Row
	Exec(ctx context.Context, args ...any) (Result, error)
}

// Registry holds registered drivers (keyed by driver name) and open
// connections (keyed by connection ID). All access is guarded by mu.
type Registry struct {
	mu      sync.RWMutex
	drivers map[string]Driver
	conns   map[string]Connection
	nextID  int // numeric suffix for the next connection ID; only increases
}

// global is the registry used by all package-level functions.
var global = &Registry{
	drivers: make(map[string]Driver),
	conns:   make(map[string]Connection),
}

// RegisterDriver registers driver under name, replacing any driver previously
// registered under the same name.
func RegisterDriver(name string, driver Driver) {
	global.mu.Lock()
	defer global.mu.Unlock()
	global.drivers[name] = driver
}

// GetDriver returns the driver registered under name and whether it exists.
func GetDriver(name string) (Driver, bool) {
	global.mu.RLock()
	defer global.mu.RUnlock()
	driver, exists := global.drivers[name]
	return driver, exists
}

// Connect opens a connection with the driver registered as driverName and
// stores it in the registry. It returns the new connection's ID, of the form
// "<driverName>_<n>", which GetConnection and CloseConnection accept.
func Connect(driverName, dsn string) (string, error) {
	driver, exists := GetDriver(driverName)
	if !exists {
		return "", fmt.Errorf("unknown driver: %s", driverName)
	}

	// Open is called without holding the registry lock so that a slow open
	// does not block other registry users.
	conn, err := driver.Open(dsn)
	if err != nil {
		return "", fmt.Errorf("open %s: %w", driverName, err)
	}

	global.mu.Lock()
	defer global.mu.Unlock()

	id := fmt.Sprintf("%s_%d", driverName, global.nextID)
	global.nextID++
	global.conns[id] = conn
	return id, nil
}

// GetConnection returns the connection registered under id and whether it
// exists.
func GetConnection(id string) (Connection, bool) {
	global.mu.RLock()
	defer global.mu.RUnlock()
	conn, exists := global.conns[id]
	return conn, exists
}

// CloseConnection removes the connection registered under id from the
// registry and closes it. The connection is removed even if Close fails; the
// Close error is returned.
func CloseConnection(id string) error {
	global.mu.Lock()
	conn, exists := global.conns[id]
	if exists {
		delete(global.conns, id)
	}
	global.mu.Unlock()

	if !exists {
		return fmt.Errorf("connection not found: %s", id)
	}
	// Close outside the lock: it may block, and the connection is no longer
	// reachable through the registry at this point.
	return conn.Close()
}

// collectLuaArgs converts the Lua values from stack index first up to the top
// of the stack into query parameters. A value that ToValue cannot convert is
// passed as nil rather than failing the call.
func collectLuaArgs(s *luajit.State, first int) []any {
	n := s.GetTop() - first + 1
	if n < 0 {
		n = 0
	}
	args := make([]any, n)
	for i := range args {
		val, err := s.ToValue(first + i)
		if err != nil {
			val = nil
		}
		args[i] = val
	}
	return args
}

// pushSQLValue pushes one scanned column value onto the Lua stack, always
// pushing exactly one value. Byte slices are pushed as Lua strings (Lua
// strings are 8-bit clean). A value PushValue cannot convert is pushed as nil,
// so that column is absent from the row table.
func pushSQLValue(s *luajit.State, v any) {
	if b, ok := v.([]byte); ok {
		s.PushString(string(b))
		return
	}
	if err := s.PushValue(v); err != nil {
		s.PushNil()
	}
}

// Lua function implementations

// luaConnect implements sql_connect(driver, dsn) -> connection id.
func luaConnect(s *luajit.State) int {
	if err := s.CheckExactArgs(2); err != nil {
		return s.PushError("connect: %v", err)
	}
	driver, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("connect: driver must be a string")
	}
	dsn, err := s.SafeToString(2)
	if err != nil {
		return s.PushError("connect: dsn must be a string")
	}

	connID, err := Connect(driver, dsn)
	if err != nil {
		return s.PushError("connect: %v", err)
	}
	s.PushString(connID)
	return 1
}

// luaClose implements sql_close(id) -> true.
func luaClose(s *luajit.State) int {
	if err := s.CheckExactArgs(1); err != nil {
		return s.PushError("close: %v", err)
	}
	connID, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("close: connection id must be a string")
	}
	if err := CloseConnection(connID); err != nil {
		return s.PushError("close: %v", err)
	}
	s.PushBoolean(true)
	return 1
}

// luaPing implements sql_ping(id) -> true, bounded by luaPingTimeout.
func luaPing(s *luajit.State) int {
	if err := s.CheckExactArgs(1); err != nil {
		return s.PushError("ping: %v", err)
	}
	connID, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("ping: connection id must be a string")
	}
	conn, exists := GetConnection(connID)
	if !exists {
		return s.PushError("ping: connection not found")
	}

	ctx, cancel := context.WithTimeout(context.Background(), luaPingTimeout)
	defer cancel()
	if err := conn.Ping(ctx); err != nil {
		return s.PushError("ping: %v", err)
	}
	s.PushBoolean(true)
	return 1
}

// luaQuery implements sql_query(id, query, ...args). It returns an array of
// row tables, each mapping column name to value. The whole call, including
// row iteration, is bounded by luaQueryTimeout.
func luaQuery(s *luajit.State) int {
	if err := s.CheckMinArgs(2); err != nil {
		return s.PushError("query: %v", err)
	}
	connID, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("query: connection id must be a string")
	}
	query, err := s.SafeToString(2)
	if err != nil {
		return s.PushError("query: query must be a string")
	}
	conn, exists := GetConnection(connID)
	if !exists {
		return s.PushError("query: connection not found")
	}

	args := collectLuaArgs(s, 3)

	ctx, cancel := context.WithTimeout(context.Background(), luaQueryTimeout)
	defer cancel()

	rows, err := conn.Query(ctx, query, args...)
	if err != nil {
		return s.PushError("query: %v", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return s.PushError("query: failed to get columns: %v", err)
	}

	// Scan buffers are reused for every row; each row's values are copied
	// into Lua before the next Scan overwrites them.
	values := make([]any, len(columns))
	ptrs := make([]any, len(columns))
	for i := range values {
		ptrs[i] = &values[i]
	}

	// Stack: [results]
	s.CreateTable(0, 0)
	for rowIndex := 1; rows.Next(); rowIndex++ {
		if err := rows.Scan(ptrs...); err != nil {
			return s.PushError("query: scan error: %v", err)
		}

		// Stack: [results, rowIndex, row]
		s.PushNumber(float64(rowIndex))
		s.CreateTable(0, len(columns))
		for i, col := range columns {
			s.PushString(col)
			pushSQLValue(s, values[i])
			s.SetTable(-3) // row[col] = value
		}
		s.SetTable(-3) // results[rowIndex] = row; stack back to [results]
	}
	if err := rows.Err(); err != nil {
		return s.PushError("query: %v", err)
	}
	return 1
}

// luaExec implements sql_exec(id, query, ...args). It returns a table with
// last_insert_id and rows_affected, bounded by luaQueryTimeout.
func luaExec(s *luajit.State) int {
	if err := s.CheckMinArgs(2); err != nil {
		return s.PushError("exec: %v", err)
	}
	connID, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("exec: connection id must be a string")
	}
	query, err := s.SafeToString(2)
	if err != nil {
		return s.PushError("exec: query must be a string")
	}
	conn, exists := GetConnection(connID)
	if !exists {
		return s.PushError("exec: connection not found")
	}

	args := collectLuaArgs(s, 3)

	ctx, cancel := context.WithTimeout(context.Background(), luaQueryTimeout)
	defer cancel()

	result, err := conn.Exec(ctx, query, args...)
	if err != nil {
		return s.PushError("exec: %v", err)
	}

	// Errors from these two accessors are deliberately ignored. The statement
	// has already executed, and not every driver can report both values. The
	// value returned alongside an error is reported as-is.
	lastID, _ := result.LastInsertId()
	affected, _ := result.RowsAffected()

	s.CreateTable(0, 2)
	s.PushString("last_insert_id")
	s.PushNumber(float64(lastID))
	s.SetTable(-3)
	s.PushString("rows_affected")
	s.PushNumber(float64(affected))
	s.SetTable(-3)
	return 1
}

// GetFunctionList returns all Lua-callable functions keyed by their Lua name.
func GetFunctionList() map[string]luajit.GoFunction {
	return map[string]luajit.GoFunction{
		"sql_connect": luaConnect,
		"sql_close":   luaClose,
		"sql_ping":    luaPing,
		"sql_query":   luaQuery,
		"sql_exec":    luaExec,
	}
}

func init() {
	// Register SQLite driver on import
	RegisterDriver("sqlite", &SQLiteDriver{})
}