368 lines
7.9 KiB
Go
368 lines
7.9 KiB
Go
package sql
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
// Driver interface for SQL database implementations
|
|
type Driver interface {
|
|
Open(dsn string) (Connection, error)
|
|
Name() string
|
|
}
|
|
|
|
// Connection represents a database connection
|
|
type Connection interface {
|
|
Close() error
|
|
Ping(ctx context.Context) error
|
|
Begin(ctx context.Context) (Transaction, error)
|
|
Query(ctx context.Context, query string, args ...any) (Rows, error)
|
|
QueryRow(ctx context.Context, query string, args ...any) Row
|
|
Exec(ctx context.Context, query string, args ...any) (Result, error)
|
|
Prepare(ctx context.Context, query string) (Statement, error)
|
|
}
|
|
|
|
// Transaction represents a database transaction
|
|
type Transaction interface {
|
|
Commit() error
|
|
Rollback() error
|
|
Query(ctx context.Context, query string, args ...any) (Rows, error)
|
|
QueryRow(ctx context.Context, query string, args ...any) Row
|
|
Exec(ctx context.Context, query string, args ...any) (Result, error)
|
|
Prepare(ctx context.Context, query string) (Statement, error)
|
|
}
|
|
|
|
// Rows iterates over the result set of a query. In luaQuery the pattern is:
// read Columns, loop while Next returns true and Scan each row, then check
// Err. Close is deferred by the caller.
type Rows interface {
	// Columns returns the result set's column names.
	Columns() ([]string, error)
	// Next advances to the next row, returning false when none remain.
	Next() bool
	// Scan copies the current row's values into the pointers in dest.
	Scan(dest ...any) error
	// Err reports any error encountered during iteration.
	Err() error
	// Close releases the result set.
	Close() error
}
|
|
|
|
// Row is the single-row result returned by QueryRow. Its only operation is
// Scan, which copies the row's values into the pointers in dest.
type Row interface {
	Scan(dest ...any) error
}
|
|
|
|
// Result is the outcome of an Exec call. luaExec reads both values and
// reports them to Lua as last_insert_id and rows_affected.
type Result interface {
	// LastInsertId returns the id of the most recently inserted row.
	LastInsertId() (int64, error)
	// RowsAffected returns how many rows the statement changed.
	RowsAffected() (int64, error)
}
|
|
|
|
// Statement represents a prepared statement
|
|
type Statement interface {
|
|
Close() error
|
|
Query(ctx context.Context, args ...any) (Rows, error)
|
|
QueryRow(ctx context.Context, args ...any) Row
|
|
Exec(ctx context.Context, args ...any) (Result, error)
|
|
}
|
|
|
|
// Registry manages database drivers and connections
|
|
type Registry struct {
|
|
mu sync.RWMutex
|
|
drivers map[string]Driver
|
|
conns map[string]Connection
|
|
nextID int
|
|
}
|
|
|
|
var global = &Registry{
|
|
drivers: make(map[string]Driver),
|
|
conns: make(map[string]Connection),
|
|
}
|
|
|
|
// RegisterDriver registers a database driver
|
|
func RegisterDriver(name string, driver Driver) {
|
|
global.mu.Lock()
|
|
defer global.mu.Unlock()
|
|
global.drivers[name] = driver
|
|
}
|
|
|
|
// GetDriver returns a registered driver
|
|
func GetDriver(name string) (Driver, bool) {
|
|
global.mu.RLock()
|
|
defer global.mu.RUnlock()
|
|
driver, exists := global.drivers[name]
|
|
return driver, exists
|
|
}
|
|
|
|
// Connect opens a connection using the driver registered as driverName and
// stores it in the registry. It returns the generated connection id, formed
// as "<driverName>_<n>" where n comes from a package-wide counter. Errors
// from the driver's Open are returned unchanged.
func Connect(driverName, dsn string) (string, error) {
	drv, known := GetDriver(driverName)
	if !known {
		return "", fmt.Errorf("unknown driver: %s", driverName)
	}

	// Open runs without the registry lock held; only the bookkeeping below
	// needs it.
	conn, err := drv.Open(dsn)
	if err != nil {
		return "", err
	}

	id := func() string {
		global.mu.Lock()
		defer global.mu.Unlock()

		seq := global.nextID
		global.nextID = seq + 1
		key := fmt.Sprintf("%s_%d", driverName, seq)
		global.conns[key] = conn
		return key
	}()
	return id, nil
}
|
|
|
|
// GetConnection retrieves a connection by ID
|
|
func GetConnection(id string) (Connection, bool) {
|
|
global.mu.RLock()
|
|
defer global.mu.RUnlock()
|
|
conn, exists := global.conns[id]
|
|
return conn, exists
|
|
}
|
|
|
|
// CloseConnection removes the connection registered under id from the
// registry and closes it. It returns an error if id is unknown; otherwise it
// returns whatever the connection's Close returns. The registry entry is
// dropped even if Close fails.
func CloseConnection(id string) error {
	global.mu.Lock()
	conn, exists := global.conns[id]
	if exists {
		delete(global.conns, id)
	}
	global.mu.Unlock()

	if !exists {
		return fmt.Errorf("connection not found: %s", id)
	}

	// Close runs outside the registry lock. It may block on I/O, and holding
	// the write lock here would stall every concurrent Connect,
	// GetConnection and RegisterDriver. The entry is already gone, so a
	// concurrent CloseConnection for the same id gets "not found" rather than
	// closing the connection twice.
	return conn.Close()
}
|
|
|
|
// Lua function implementations
|
|
|
|
func luaConnect(s *luajit.State) int {
|
|
if err := s.CheckExactArgs(2); err != nil {
|
|
return s.PushError("connect: %v", err)
|
|
}
|
|
|
|
driver, err := s.SafeToString(1)
|
|
if err != nil {
|
|
return s.PushError("connect: driver must be a string")
|
|
}
|
|
|
|
dsn, err := s.SafeToString(2)
|
|
if err != nil {
|
|
return s.PushError("connect: dsn must be a string")
|
|
}
|
|
|
|
connID, err := Connect(driver, dsn)
|
|
if err != nil {
|
|
return s.PushError("connect: %v", err)
|
|
}
|
|
|
|
s.PushString(connID)
|
|
return 1
|
|
}
|
|
|
|
func luaClose(s *luajit.State) int {
|
|
if err := s.CheckExactArgs(1); err != nil {
|
|
return s.PushError("close: %v", err)
|
|
}
|
|
|
|
connID, err := s.SafeToString(1)
|
|
if err != nil {
|
|
return s.PushError("close: connection id must be a string")
|
|
}
|
|
|
|
if err := CloseConnection(connID); err != nil {
|
|
return s.PushError("close: %v", err)
|
|
}
|
|
|
|
s.PushBoolean(true)
|
|
return 1
|
|
}
|
|
|
|
func luaPing(s *luajit.State) int {
|
|
if err := s.CheckExactArgs(1); err != nil {
|
|
return s.PushError("ping: %v", err)
|
|
}
|
|
|
|
connID, err := s.SafeToString(1)
|
|
if err != nil {
|
|
return s.PushError("ping: connection id must be a string")
|
|
}
|
|
|
|
conn, exists := GetConnection(connID)
|
|
if !exists {
|
|
return s.PushError("ping: connection not found")
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
if err := conn.Ping(ctx); err != nil {
|
|
return s.PushError("ping: %v", err)
|
|
}
|
|
|
|
s.PushBoolean(true)
|
|
return 1
|
|
}
|
|
|
|
// luaStatementTimeout bounds how long a single sql_query or sql_exec call may
// run before its context is cancelled.
const luaStatementTimeout = 30 * time.Second

// collectLuaArgs converts the Lua values from stack index first through the
// top of the stack into Go values to bind as SQL parameters. The returned
// slice is never nil (it is empty when there are no extra arguments).
func collectLuaArgs(s *luajit.State, first int) []any {
	n := s.GetTop() - first + 1
	if n < 0 {
		n = 0
	}
	args := make([]any, n)
	for i := range args {
		// NOTE(review): a value ToValue cannot convert is bound as nil
		// (SQL NULL) instead of being reported. This matches the previous
		// behavior, but it can silently change a statement's meaning;
		// consider returning an error instead.
		if val, err := s.ToValue(first + i); err == nil {
			args[i] = val
		}
	}
	return args
}

// luaQuery backs the Lua function sql_query(conn_id, query, ...). Arguments
// after the query are bound as parameters. On success it pushes an array
// table with one table per result row, each mapping column name to value,
// and returns 1. Failures are reported through s.PushError.
func luaQuery(s *luajit.State) int {
	if err := s.CheckMinArgs(2); err != nil {
		return s.PushError("query: %v", err)
	}

	connID, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("query: connection id must be a string")
	}
	query, err := s.SafeToString(2)
	if err != nil {
		return s.PushError("query: query must be a string")
	}

	conn, exists := GetConnection(connID)
	if !exists {
		return s.PushError("query: connection not found")
	}
	args := collectLuaArgs(s, 3)

	ctx, cancel := context.WithTimeout(context.Background(), luaStatementTimeout)
	defer cancel()

	rows, err := conn.Query(ctx, query, args...)
	if err != nil {
		return s.PushError("query: %v", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return s.PushError("query: failed to get columns: %v", err)
	}

	// Scan destinations are allocated once and reused for every row. Each
	// row's values are pushed into Lua before the next Scan overwrites them.
	values := make([]any, len(columns))
	dest := make([]any, len(columns))
	for i := range values {
		dest[i] = &values[i]
	}

	s.CreateTable(0, 0) // result array
	for rowIndex := 1; rows.Next(); rowIndex++ {
		// Reset so a Scan that leaves a destination untouched cannot carry
		// the previous row's value into this one.
		for i := range values {
			values[i] = nil
		}
		if err := rows.Scan(dest...); err != nil {
			return s.PushError("query: scan error: %v", err)
		}

		// Stack layout: result, index, row. Pushing the index before the row
		// table lets one SetTable(-3) store the finished row without copying.
		s.PushNumber(float64(rowIndex))
		s.CreateTable(0, len(columns))
		for i, col := range columns {
			s.PushString(col)
			if err := s.PushValue(values[i]); err != nil {
				s.PushNil()
			}
			s.SetTable(-3) // row[col] = value
		}
		s.SetTable(-3) // result[rowIndex] = row
	}

	if err := rows.Err(); err != nil {
		return s.PushError("query: %v", err)
	}
	return 1
}

// luaExec backs the Lua function sql_exec(conn_id, query, ...). Arguments
// after the query are bound as parameters. On success it pushes a table with
// last_insert_id and rows_affected fields and returns 1. Failures are
// reported through s.PushError.
func luaExec(s *luajit.State) int {
	if err := s.CheckMinArgs(2); err != nil {
		return s.PushError("exec: %v", err)
	}

	connID, err := s.SafeToString(1)
	if err != nil {
		return s.PushError("exec: connection id must be a string")
	}
	query, err := s.SafeToString(2)
	if err != nil {
		return s.PushError("exec: query must be a string")
	}

	conn, exists := GetConnection(connID)
	if !exists {
		return s.PushError("exec: connection not found")
	}
	args := collectLuaArgs(s, 3)

	ctx, cancel := context.WithTimeout(context.Background(), luaStatementTimeout)
	defer cancel()

	result, err := conn.Exec(ctx, query, args...)
	if err != nil {
		return s.PushError("exec: %v", err)
	}

	// Errors from these two calls are deliberately ignored: the statement has
	// already succeeded, and presumably some drivers cannot report one or
	// both values. Whatever value accompanies such an error is reported as-is.
	lastID, _ := result.LastInsertId()
	affected, _ := result.RowsAffected()

	// PushNumber takes a float64, so values above 2^53 lose precision.
	s.CreateTable(0, 2)
	s.PushString("last_insert_id")
	s.PushNumber(float64(lastID))
	s.SetTable(-3)
	s.PushString("rows_affected")
	s.PushNumber(float64(affected))
	s.SetTable(-3)
	return 1
}
|
|
|
|
// GetFunctionList returns all Lua-callable functions
|
|
func GetFunctionList() map[string]luajit.GoFunction {
|
|
return map[string]luajit.GoFunction{
|
|
"sql_connect": luaConnect,
|
|
"sql_close": luaClose,
|
|
"sql_ping": luaPing,
|
|
"sql_query": luaQuery,
|
|
"sql_exec": luaExec,
|
|
}
|
|
}
|
|
|
|
func init() {
|
|
// Register SQLite driver on import
|
|
RegisterDriver("sqlite", &SQLiteDriver{})
|
|
}
|