378 lines
8.0 KiB
Go

package sql
import (
"context"
"fmt"
"sync"
"time"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
type (
	// Driver is implemented by each SQL backend. Drivers are made available
	// with RegisterDriver and used by Connect to open new connections.
	Driver interface {
		// Name returns the driver's name.
		Name() string
		// Open creates a new connection described by dsn.
		Open(dsn string) (Connection, error)
	}

	// Connection is an open handle to a database.
	Connection interface {
		// Query runs a statement that returns rows.
		Query(ctx context.Context, query string, args ...any) (Rows, error)
		// QueryRow runs a statement expected to return at most one row.
		QueryRow(ctx context.Context, query string, args ...any) Row
		// Exec runs a statement that does not return rows.
		Exec(ctx context.Context, query string, args ...any) (Result, error)
		// Prepare creates a reusable prepared statement.
		Prepare(ctx context.Context, query string) (Statement, error)
		// Begin starts a transaction.
		Begin(ctx context.Context) (Transaction, error)
		// Ping checks that the connection is usable.
		Ping(ctx context.Context) error
		// Close releases the connection.
		Close() error
	}

	// Transaction is a unit of work started by Connection.Begin. It offers
	// the same statement methods as Connection plus Commit and Rollback.
	Transaction interface {
		Query(ctx context.Context, query string, args ...any) (Rows, error)
		QueryRow(ctx context.Context, query string, args ...any) Row
		Exec(ctx context.Context, query string, args ...any) (Result, error)
		Prepare(ctx context.Context, query string) (Statement, error)
		// Commit makes the transaction's changes permanent.
		Commit() error
		// Rollback abandons the transaction's changes.
		Rollback() error
	}

	// Rows iterates over the result set of a query.
	Rows interface {
		// Columns returns the result column names.
		Columns() ([]string, error)
		// Next advances to the next row, reporting whether one exists.
		Next() bool
		// Scan copies the current row's columns into dest.
		Scan(dest ...any) error
		// Err reports any error encountered during iteration.
		Err() error
		// Close releases the result set.
		Close() error
	}

	// Row is the single-row result of a QueryRow call.
	Row interface {
		Scan(dest ...any) error
	}

	// Result describes the outcome of an Exec call.
	Result interface {
		LastInsertId() (int64, error)
		RowsAffected() (int64, error)
	}

	// Statement is a prepared statement that can be executed repeatedly
	// with different arguments.
	Statement interface {
		Query(ctx context.Context, args ...any) (Rows, error)
		QueryRow(ctx context.Context, args ...any) Row
		Exec(ctx context.Context, args ...any) (Result, error)
		Close() error
	}
)
// Registry tracks registered drivers (keyed by driver name) and open
// connections (keyed by generated connection ID). The functions in this
// file access every field only while holding mu.
type Registry struct {
	mu      sync.RWMutex
	drivers map[string]Driver     // driver name -> driver
	conns   map[string]Connection // connection ID -> open connection
	nextID  int                   // counter appended to IDs built by Connect
}

// global is the package-wide registry behind RegisterDriver, Connect and
// the other package-level helpers.
var global = &Registry{
	conns:   map[string]Connection{},
	drivers: map[string]Driver{},
}
// RegisterDriver makes driver available to Connect under name. Registering
// again under an existing name replaces the previous driver.
func RegisterDriver(name string, driver Driver) {
	r := global
	r.mu.Lock()
	defer r.mu.Unlock()

	r.drivers[name] = driver
}
// GetDriver looks up the driver registered under name. The boolean result
// reports whether such a driver exists.
func GetDriver(name string) (Driver, bool) {
	global.mu.RLock()
	d, ok := global.drivers[name]
	global.mu.RUnlock()

	return d, ok
}
// Connect opens a connection using the driver registered as driverName and
// stores it in the registry. It returns the generated connection ID, of the
// form "<driverName>_<n>", which identifies the connection in later calls.
func Connect(driverName, dsn string) (string, error) {
	drv, ok := GetDriver(driverName)
	if !ok {
		return "", fmt.Errorf("unknown driver: %s", driverName)
	}

	// Open is called without holding the registry lock, so a slow open does
	// not block other registry users.
	conn, err := drv.Open(dsn)
	if err != nil {
		return "", err
	}

	global.mu.Lock()
	defer global.mu.Unlock()

	id := fmt.Sprintf("%s_%d", driverName, global.nextID)
	global.conns[id] = conn
	global.nextID++
	return id, nil
}
// GetConnection returns the open connection registered under id. The
// boolean result reports whether it was found.
func GetConnection(id string) (Connection, bool) {
	global.mu.RLock()
	defer global.mu.RUnlock()

	if conn, ok := global.conns[id]; ok {
		return conn, true
	}
	return nil, false
}
// CloseConnection removes the connection registered under id from the
// registry and closes it, returning the error from Close. An unknown id
// yields a "connection not found" error.
//
// The entry is removed while holding the registry lock, but Close runs
// after the lock is released. A slow Close therefore does not stall
// unrelated lookups, and no other caller can fetch a connection that is
// being closed.
func CloseConnection(id string) error {
	global.mu.Lock()
	conn, exists := global.conns[id]
	if exists {
		delete(global.conns, id)
	}
	global.mu.Unlock()

	if !exists {
		return fmt.Errorf("connection not found: %s", id)
	}
	return conn.Close()
}
// CloseAllConnections closes every registered connection and empties the
// registry. It is best-effort: errors from individual Close calls are
// discarded because this function has no result to report them through.
//
// The connection map is detached while holding the registry lock, and the
// connections are closed after the lock is released. Slow closes therefore
// do not block other registry users.
func CloseAllConnections() {
	global.mu.Lock()
	conns := global.conns
	global.conns = make(map[string]Connection)
	global.mu.Unlock()

	for _, conn := range conns {
		_ = conn.Close() // best-effort; errors intentionally dropped
	}
}
// Lua function implementations
// luaConnect implements sql_connect(driver, dsn).
//
// On success it pushes the new connection ID string and returns 1. On any
// failure it returns the result of s.PushError (presumably the number of
// values pushed; confirm against the luajit package).
func luaConnect(s *luajit.State) int {
	if err := s.CheckExactArgs(2); err != nil {
		return s.PushError("connect: %v", err)
	}

	driverName, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("connect: driver must be a string")
	}
	dsn, err := s.SafeToString(2)
	if err != nil {
		return s.PushError("connect: dsn must be a string")
	}

	id, err := Connect(driverName, dsn)
	if err != nil {
		return s.PushError("connect: %v", err)
	}
	s.PushString(id)
	return 1
}
// luaClose implements sql_close(conn_id). It closes and unregisters the
// connection, then pushes true on success. Failures are reported via
// s.PushError.
func luaClose(s *luajit.State) int {
	if err := s.CheckExactArgs(1); err != nil {
		return s.PushError("close: %v", err)
	}

	id, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("close: connection id must be a string")
	}

	if err = CloseConnection(id); err != nil {
		return s.PushError("close: %v", err)
	}
	s.PushBoolean(true)
	return 1
}
// luaPing implements sql_ping(conn_id). It pings the connection with a
// five-second deadline and pushes true if the ping succeeds. Failures are
// reported via s.PushError.
func luaPing(s *luajit.State) int {
	if err := s.CheckExactArgs(1); err != nil {
		return s.PushError("ping: %v", err)
	}

	id, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("ping: connection id must be a string")
	}
	conn, ok := GetConnection(id)
	if !ok {
		return s.PushError("ping: connection not found")
	}

	const pingTimeout = 5 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
	defer cancel()

	if err = conn.Ping(ctx); err != nil {
		return s.PushError("ping: %v", err)
	}
	s.PushBoolean(true)
	return 1
}
// luaQuery implements sql_query(conn_id, query, ...).
//
// Arguments after the query string are converted with s.ToValue and bound
// as query parameters. Any argument whose conversion fails is bound as nil.
// The query runs with a 30-second deadline.
//
// On success it pushes one array-like table (indexed from 1). Each entry is
// a row table mapping column name to value; values that s.PushValue cannot
// push are stored as nil. Failures are reported via s.PushError.
func luaQuery(s *luajit.State) int {
	if err := s.CheckMinArgs(2); err != nil {
		return s.PushError("query: %v", err)
	}
	connID, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("query: connection id must be a string")
	}
	query, err := s.SafeToString(2)
	if err != nil {
		return s.PushError("query: query must be a string")
	}
	conn, exists := GetConnection(connID)
	if !exists {
		return s.PushError("query: connection not found")
	}

	// Bind parameters come from stack slots 3..top. Slots that fail
	// conversion keep their zero value (nil).
	top := s.GetTop()
	args := make([]any, top-2)
	for i := range args {
		if val, convErr := s.ToValue(i + 3); convErr == nil {
			args[i] = val
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	rows, err := conn.Query(ctx, query, args...)
	if err != nil {
		return s.PushError("query: %v", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return s.PushError("query: failed to get columns: %v", err)
	}

	// The scan buffers are loop-invariant, so allocate them once. values
	// receives the scanned columns; ptrs holds the &values[i] destinations
	// that Scan writes through.
	values := make([]any, len(columns))
	ptrs := make([]any, len(columns))
	for i := range values {
		ptrs[i] = &values[i]
	}

	s.CreateTable(0, 0) // result array
	for rowIndex := 1; rows.Next(); rowIndex++ {
		// Reset the previous row so that any column Scan leaves untouched
		// reads as nil, just as a freshly allocated buffer would.
		for i := range values {
			values[i] = nil
		}
		if err := rows.Scan(ptrs...); err != nil {
			return s.PushError("query: scan error: %v", err)
		}

		// Stack: [result, row]
		s.CreateTable(0, len(columns))
		for i, col := range columns {
			s.PushString(col)
			if err := s.PushValue(values[i]); err != nil {
				s.PushNil()
			}
			s.SetTable(-3) // row[col] = value
		}

		// result[rowIndex] = row. The copy is consumed by SetTable; the Pop
		// drops the original row table, leaving only the result array.
		s.PushNumber(float64(rowIndex))
		s.PushCopy(-2)
		s.SetTable(-4)
		s.Pop(1)
	}
	if err := rows.Err(); err != nil {
		return s.PushError("query: %v", err)
	}
	return 1
}
// luaExec implements sql_exec(conn_id, query, ...).
//
// Extra arguments are bound as statement parameters; an argument that
// s.ToValue cannot convert is bound as nil. The statement runs with a
// 30-second deadline. On success it pushes a table with the fields
// last_insert_id and rows_affected. Failures are reported via s.PushError.
func luaExec(s *luajit.State) int {
	if err := s.CheckMinArgs(2); err != nil {
		return s.PushError("exec: %v", err)
	}
	connID, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("exec: connection id must be a string")
	}
	query, err := s.SafeToString(2)
	if err != nil {
		return s.PushError("exec: query must be a string")
	}
	conn, ok := GetConnection(connID)
	if !ok {
		return s.PushError("exec: connection not found")
	}

	n := s.GetTop() - 2
	params := make([]any, n)
	for idx := 0; idx < n; idx++ {
		v, convErr := s.ToValue(idx + 3)
		if convErr != nil {
			v = nil
		}
		params[idx] = v
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	res, err := conn.Exec(ctx, query, params...)
	if err != nil {
		return s.PushError("exec: %v", err)
	}

	// Errors from LastInsertId/RowsAffected are deliberately ignored,
	// presumably because not every driver supports them. The field then
	// holds whatever value accompanied the error (normally 0).
	lastID, _ := res.LastInsertId()
	affected, _ := res.RowsAffected()

	s.CreateTable(0, 2)
	s.PushString("last_insert_id")
	s.PushNumber(float64(lastID))
	s.SetTable(-3)
	s.PushString("rows_affected")
	s.PushNumber(float64(affected))
	s.SetTable(-3)
	return 1
}
// GetFunctionList returns the Go functions this package exposes to Lua,
// keyed by their Lua-facing names.
func GetFunctionList() map[string]luajit.GoFunction {
	fns := make(map[string]luajit.GoFunction, 5)
	fns["sql_connect"] = luaConnect
	fns["sql_close"] = luaClose
	fns["sql_ping"] = luaPing
	fns["sql_query"] = luaQuery
	fns["sql_exec"] = luaExec
	return fns
}
// init registers the SQLite driver under the name "sqlite" when this
// package is imported.
func init() {
	RegisterDriver("sqlite", new(SQLiteDriver))
}