467 lines
10 KiB
Go
467 lines
10 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
sqlite "zombiezen.com/go/sqlite"
|
|
"zombiezen.com/go/sqlite/sqlitex"
|
|
|
|
"Moonshark/logger"
|
|
|
|
"git.sharkk.net/Go/Color"
|
|
|
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
|
)
|
|
|
|
var (
|
|
dbPools = make(map[string]*sqlitex.Pool)
|
|
poolsMu sync.RWMutex
|
|
dataDir string
|
|
poolSize = 8
|
|
connTimeout = 5 * time.Second
|
|
|
|
// Per-state connection cache
|
|
stateConns = make(map[string]*stateConn)
|
|
stateConnsMu sync.RWMutex
|
|
)
|
|
|
|
// stateConn tracks a connection and its origin pool
|
|
type stateConn struct {
|
|
conn *sqlite.Conn
|
|
pool *sqlitex.Pool
|
|
}
|
|
|
|
func InitSQLite(dir string) {
|
|
dataDir = dir
|
|
logger.Infof("SQLite is g2g! %s", color.Yellow(dir))
|
|
}
|
|
|
|
func SetSQLitePoolSize(size int) {
|
|
if size > 0 {
|
|
poolSize = size
|
|
}
|
|
}
|
|
|
|
func CleanupSQLite() {
|
|
poolsMu.Lock()
|
|
defer poolsMu.Unlock()
|
|
|
|
// Return all cached connections to their pools
|
|
stateConnsMu.Lock()
|
|
for _, sc := range stateConns {
|
|
if sc.pool != nil && sc.conn != nil {
|
|
sc.pool.Put(sc.conn)
|
|
}
|
|
}
|
|
stateConns = make(map[string]*stateConn)
|
|
stateConnsMu.Unlock()
|
|
|
|
for name, pool := range dbPools {
|
|
if err := pool.Close(); err != nil {
|
|
logger.Errorf("Failed to close database %s: %v", name, err)
|
|
}
|
|
}
|
|
|
|
dbPools = make(map[string]*sqlitex.Pool)
|
|
logger.Debugf("SQLite connections closed")
|
|
}
|
|
|
|
func getPool(dbName string) (*sqlitex.Pool, error) {
|
|
dbName = filepath.Base(dbName)
|
|
if dbName == "" || dbName[0] == '.' {
|
|
return nil, fmt.Errorf("invalid database name")
|
|
}
|
|
|
|
poolsMu.RLock()
|
|
pool, exists := dbPools[dbName]
|
|
if exists {
|
|
poolsMu.RUnlock()
|
|
return pool, nil
|
|
}
|
|
poolsMu.RUnlock()
|
|
|
|
poolsMu.Lock()
|
|
defer poolsMu.Unlock()
|
|
|
|
if pool, exists = dbPools[dbName]; exists {
|
|
return pool, nil
|
|
}
|
|
|
|
dbPath := filepath.Join(dataDir, dbName+".db")
|
|
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
|
PoolSize: poolSize,
|
|
PrepareConn: func(conn *sqlite.Conn) error {
|
|
pragmas := []string{
|
|
"PRAGMA journal_mode = WAL",
|
|
"PRAGMA synchronous = NORMAL",
|
|
"PRAGMA cache_size = 1000",
|
|
"PRAGMA foreign_keys = ON",
|
|
"PRAGMA temp_store = MEMORY",
|
|
}
|
|
for _, pragma := range pragmas {
|
|
if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
dbPools[dbName] = pool
|
|
logger.Debugf("Created SQLite pool for %s (size: %d)", dbName, poolSize)
|
|
return pool, nil
|
|
}
|
|
|
|
// getStateConnection gets or creates a reusable connection for the state+db
|
|
func getStateConnection(stateIndex int, dbName string) (*sqlite.Conn, error) {
|
|
connKey := fmt.Sprintf("%d-%s", stateIndex, dbName)
|
|
|
|
stateConnsMu.RLock()
|
|
sc, exists := stateConns[connKey]
|
|
stateConnsMu.RUnlock()
|
|
|
|
if exists && sc.conn != nil {
|
|
return sc.conn, nil
|
|
}
|
|
|
|
// Get new connection from pool
|
|
pool, err := getPool(dbName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
|
defer cancel()
|
|
|
|
conn, err := pool.Take(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("connection timeout: %w", err)
|
|
}
|
|
|
|
// Cache it with pool reference
|
|
stateConnsMu.Lock()
|
|
stateConns[connKey] = &stateConn{
|
|
conn: conn,
|
|
pool: pool,
|
|
}
|
|
stateConnsMu.Unlock()
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
func sqlQuery(state *luajit.State) int {
|
|
if err := state.CheckMinArgs(3); err != nil {
|
|
return state.PushError("sqlite.query: %v", err)
|
|
}
|
|
|
|
dbName, err := state.SafeToString(1)
|
|
if err != nil {
|
|
return state.PushError("sqlite.query: database name must be string")
|
|
}
|
|
|
|
query, err := state.SafeToString(2)
|
|
if err != nil {
|
|
return state.PushError("sqlite.query: query must be string")
|
|
}
|
|
|
|
stateIndex := int(state.ToNumber(-1))
|
|
|
|
conn, err := getStateConnection(stateIndex, dbName)
|
|
if err != nil {
|
|
return state.PushError("sqlite.query: %v", err)
|
|
}
|
|
|
|
var execOpts sqlitex.ExecOptions
|
|
rows := make([]any, 0, 16)
|
|
|
|
if state.GetTop() >= 4 && !state.IsNil(3) {
|
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
|
return state.PushError("sqlite.query: %v", err)
|
|
}
|
|
}
|
|
|
|
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
|
row := make(map[string]any)
|
|
colCount := stmt.ColumnCount()
|
|
|
|
for i := range colCount {
|
|
colName := stmt.ColumnName(i)
|
|
switch stmt.ColumnType(i) {
|
|
case sqlite.TypeInteger:
|
|
row[colName] = stmt.ColumnInt64(i)
|
|
case sqlite.TypeFloat:
|
|
row[colName] = stmt.ColumnFloat(i)
|
|
case sqlite.TypeText:
|
|
row[colName] = stmt.ColumnText(i)
|
|
case sqlite.TypeBlob:
|
|
blobSize := stmt.ColumnLen(i)
|
|
if blobSize > 0 {
|
|
buf := make([]byte, blobSize)
|
|
row[colName] = stmt.ColumnBytes(i, buf)
|
|
} else {
|
|
row[colName] = []byte{}
|
|
}
|
|
case sqlite.TypeNull:
|
|
row[colName] = nil
|
|
}
|
|
}
|
|
rows = append(rows, row)
|
|
return nil
|
|
}
|
|
|
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
|
return state.PushError("sqlite.query: %v", err)
|
|
}
|
|
|
|
if err := state.PushValue(rows); err != nil {
|
|
return state.PushError("sqlite.query: %v", err)
|
|
}
|
|
|
|
return 1
|
|
}
|
|
|
|
func sqlExec(state *luajit.State) int {
|
|
if err := state.CheckMinArgs(3); err != nil {
|
|
return state.PushError("sqlite.exec: %v", err)
|
|
}
|
|
|
|
dbName, err := state.SafeToString(1)
|
|
if err != nil {
|
|
return state.PushError("sqlite.exec: database name must be string")
|
|
}
|
|
|
|
query, err := state.SafeToString(2)
|
|
if err != nil {
|
|
return state.PushError("sqlite.exec: query must be string")
|
|
}
|
|
|
|
stateIndex := int(state.ToNumber(-1))
|
|
|
|
conn, err := getStateConnection(stateIndex, dbName)
|
|
if err != nil {
|
|
return state.PushError("sqlite.exec: %v", err)
|
|
}
|
|
|
|
hasParams := state.GetTop() >= 4 && !state.IsNil(3)
|
|
|
|
if strings.Contains(query, ";") && !hasParams {
|
|
if err := sqlitex.ExecScript(conn, query); err != nil {
|
|
return state.PushError("sqlite.exec: %v", err)
|
|
}
|
|
state.PushNumber(float64(conn.Changes()))
|
|
return 1
|
|
}
|
|
|
|
if !hasParams {
|
|
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
|
return state.PushError("sqlite.exec: %v", err)
|
|
}
|
|
state.PushNumber(float64(conn.Changes()))
|
|
return 1
|
|
}
|
|
|
|
var execOpts sqlitex.ExecOptions
|
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
|
return state.PushError("sqlite.exec: %v", err)
|
|
}
|
|
|
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
|
return state.PushError("sqlite.exec: %v", err)
|
|
}
|
|
|
|
state.PushNumber(float64(conn.Changes()))
|
|
return 1
|
|
}
|
|
|
|
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
|
|
if state.IsTable(paramIndex) {
|
|
paramsAny, err := state.ToTable(paramIndex)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid parameters: %w", err)
|
|
}
|
|
|
|
// Handle direct array types
|
|
if arrParams, ok := paramsAny.([]any); ok {
|
|
execOpts.Args = arrParams
|
|
return nil
|
|
}
|
|
if strArr, ok := paramsAny.([]string); ok {
|
|
args := make([]any, len(strArr))
|
|
for i, v := range strArr {
|
|
args[i] = v
|
|
}
|
|
execOpts.Args = args
|
|
return nil
|
|
}
|
|
if floatArr, ok := paramsAny.([]float64); ok {
|
|
args := make([]any, len(floatArr))
|
|
for i, v := range floatArr {
|
|
args[i] = v
|
|
}
|
|
execOpts.Args = args
|
|
return nil
|
|
}
|
|
|
|
params, ok := paramsAny.(map[string]any)
|
|
if !ok {
|
|
return fmt.Errorf("unsupported parameter type: %T", paramsAny)
|
|
}
|
|
|
|
// Check for array-style parameters (empty string key indicates array)
|
|
if arr, ok := params[""]; ok {
|
|
if arrParams, ok := arr.([]any); ok {
|
|
execOpts.Args = arrParams
|
|
} else if floatArr, ok := arr.([]float64); ok {
|
|
args := make([]any, len(floatArr))
|
|
for i, v := range floatArr {
|
|
args[i] = v
|
|
}
|
|
execOpts.Args = args
|
|
}
|
|
} else {
|
|
// Named parameters
|
|
named := make(map[string]any, len(params))
|
|
for k, v := range params {
|
|
if len(k) > 0 && k[0] != ':' {
|
|
named[":"+k] = v
|
|
} else {
|
|
named[k] = v
|
|
}
|
|
}
|
|
execOpts.Named = named
|
|
}
|
|
} else {
|
|
// Multiple individual parameters
|
|
count := state.GetTop() - 2
|
|
args := make([]any, count)
|
|
for i := range count {
|
|
idx := i + 3
|
|
val, err := state.ToValue(idx)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid parameter %d: %w", i+1, err)
|
|
}
|
|
args[i] = val
|
|
}
|
|
execOpts.Args = args
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func sqlGetOne(state *luajit.State) int {
|
|
if err := state.CheckMinArgs(3); err != nil {
|
|
return state.PushError("sqlite.get_one: %v", err)
|
|
}
|
|
|
|
dbName, err := state.SafeToString(1)
|
|
if err != nil {
|
|
return state.PushError("sqlite.get_one: database name must be string")
|
|
}
|
|
|
|
query, err := state.SafeToString(2)
|
|
if err != nil {
|
|
return state.PushError("sqlite.get_one: query must be string")
|
|
}
|
|
|
|
stateIndex := int(state.ToNumber(-1))
|
|
|
|
conn, err := getStateConnection(stateIndex, dbName)
|
|
if err != nil {
|
|
return state.PushError("sqlite.get_one: %v", err)
|
|
}
|
|
|
|
var execOpts sqlitex.ExecOptions
|
|
var result map[string]any
|
|
|
|
// Check if params provided (before state index)
|
|
if state.GetTop() >= 4 && !state.IsNil(3) {
|
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
|
return state.PushError("sqlite.get_one: %v", err)
|
|
}
|
|
}
|
|
|
|
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
|
if result != nil {
|
|
return nil
|
|
}
|
|
|
|
result = make(map[string]any)
|
|
colCount := stmt.ColumnCount()
|
|
|
|
for i := range colCount {
|
|
colName := stmt.ColumnName(i)
|
|
switch stmt.ColumnType(i) {
|
|
case sqlite.TypeInteger:
|
|
result[colName] = stmt.ColumnInt64(i)
|
|
case sqlite.TypeFloat:
|
|
result[colName] = stmt.ColumnFloat(i)
|
|
case sqlite.TypeText:
|
|
result[colName] = stmt.ColumnText(i)
|
|
case sqlite.TypeBlob:
|
|
blobSize := stmt.ColumnLen(i)
|
|
if blobSize > 0 {
|
|
buf := make([]byte, blobSize)
|
|
result[colName] = stmt.ColumnBytes(i, buf)
|
|
} else {
|
|
result[colName] = []byte{}
|
|
}
|
|
case sqlite.TypeNull:
|
|
result[colName] = nil
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
|
return state.PushError("sqlite.get_one: %v", err)
|
|
}
|
|
|
|
if result == nil {
|
|
state.PushNil()
|
|
} else {
|
|
if err := state.PushValue(result); err != nil {
|
|
return state.PushError("sqlite.get_one: %v", err)
|
|
}
|
|
}
|
|
|
|
return 1
|
|
}
|
|
|
|
// CleanupStateConnection releases all connections for a specific state
|
|
func CleanupStateConnection(stateIndex int) {
|
|
stateConnsMu.Lock()
|
|
defer stateConnsMu.Unlock()
|
|
|
|
statePrefix := fmt.Sprintf("%d-", stateIndex)
|
|
|
|
for key, sc := range stateConns {
|
|
if strings.HasPrefix(key, statePrefix) {
|
|
if sc.pool != nil && sc.conn != nil {
|
|
sc.pool.Put(sc.conn)
|
|
}
|
|
delete(stateConns, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func RegisterSQLiteFunctions(state *luajit.State) error {
|
|
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
|
return err
|
|
}
|
|
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
|
|
return err
|
|
}
|
|
if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|