major sqlite bug fix, minor lua state closing fix, add headers to lua ctx
This commit is contained in:
parent
5b698f31e4
commit
266da9fd23
@ -210,10 +210,17 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
|||||||
luaCtx.Set("host", string(ctx.Host()))
|
luaCtx.Set("host", string(ctx.Host()))
|
||||||
luaCtx.Set("session", sessionMap)
|
luaCtx.Set("session", sessionMap)
|
||||||
|
|
||||||
|
// Add headers to context
|
||||||
|
headers := make(map[string]any)
|
||||||
|
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||||
|
headers[string(key)] = string(value)
|
||||||
|
})
|
||||||
|
luaCtx.Set("headers", headers)
|
||||||
|
|
||||||
// Handle params
|
// Handle params
|
||||||
if params != nil && params.Count > 0 {
|
if params != nil && params.Count > 0 {
|
||||||
paramMap := s.paramsPool.Get().(map[string]any)
|
paramMap := s.paramsPool.Get().(map[string]any)
|
||||||
for i := 0; i < params.Count; i++ {
|
for i := range params.Count {
|
||||||
paramMap[params.Keys[i]] = params.Values[i]
|
paramMap[params.Keys[i]] = params.Values[i]
|
||||||
}
|
}
|
||||||
luaCtx.Set("params", paramMap)
|
luaCtx.Set("params", paramMap)
|
||||||
|
@ -16,9 +16,7 @@ local connection_mt = {
|
|||||||
end
|
end
|
||||||
|
|
||||||
local normalized_params = normalize_params(params, ...)
|
local normalized_params = normalize_params(params, ...)
|
||||||
local results, token = __sqlite_query(self.db_name, query, normalized_params, self.conn_token)
|
return __sqlite_query(self.db_name, query, normalized_params)
|
||||||
self.conn_token = token
|
|
||||||
return results
|
|
||||||
end,
|
end,
|
||||||
|
|
||||||
exec = function(self, query, params, ...)
|
exec = function(self, query, params, ...)
|
||||||
@ -27,18 +25,16 @@ local connection_mt = {
|
|||||||
end
|
end
|
||||||
|
|
||||||
local normalized_params = normalize_params(params, ...)
|
local normalized_params = normalize_params(params, ...)
|
||||||
local affected, token = __sqlite_exec(self.db_name, query, normalized_params, self.conn_token)
|
return __sqlite_exec(self.db_name, query, normalized_params)
|
||||||
self.conn_token = token
|
|
||||||
return affected
|
|
||||||
end,
|
end,
|
||||||
|
|
||||||
close = function(self)
|
get_one = function(self, query, params, ...)
|
||||||
if self.conn_token then
|
if type(query) ~= "string" then
|
||||||
local success = __sqlite_close(self.conn_token)
|
error("connection:get_one: query must be a string", 2)
|
||||||
self.conn_token = nil
|
|
||||||
return success
|
|
||||||
end
|
end
|
||||||
return false
|
|
||||||
|
local normalized_params = normalize_params(params, ...)
|
||||||
|
return __sqlite_get_one(self.db_name, query, normalized_params)
|
||||||
end,
|
end,
|
||||||
|
|
||||||
insert = function(self, table_name, data, columns)
|
insert = function(self, table_name, data, columns)
|
||||||
@ -249,20 +245,6 @@ local connection_mt = {
|
|||||||
return self:exec(query, normalize_params(params, ...))
|
return self:exec(query, normalize_params(params, ...))
|
||||||
end,
|
end,
|
||||||
|
|
||||||
get_one = function(self, query, params, ...)
|
|
||||||
if type(query) ~= "string" then
|
|
||||||
error("connection:get_one: query must be a string", 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
local limited_query = query
|
|
||||||
if not string.contains(query:lower(), "limit") then
|
|
||||||
limited_query = query .. " LIMIT 1"
|
|
||||||
end
|
|
||||||
|
|
||||||
local results = self:query(limited_query, normalize_params(params, ...))
|
|
||||||
return results[1]
|
|
||||||
end,
|
|
||||||
|
|
||||||
exists = function(self, table_name, where, params, ...)
|
exists = function(self, table_name, where, params, ...)
|
||||||
if type(table_name) ~= "string" then
|
if type(table_name) ~= "string" then
|
||||||
error("connection:exists: table_name must be a string", 2)
|
error("connection:exists: table_name must be a string", 2)
|
||||||
@ -310,7 +292,6 @@ return function(db_name)
|
|||||||
end
|
end
|
||||||
|
|
||||||
return setmetatable({
|
return setmetatable({
|
||||||
db_name = db_name,
|
db_name = db_name
|
||||||
conn_token = nil
|
|
||||||
}, connection_mt)
|
}, connection_mt)
|
||||||
end
|
end
|
||||||
|
101
runner/runner.go
101
runner/runner.go
@ -33,7 +33,7 @@ type State struct {
|
|||||||
L *luajit.State // The Lua state
|
L *luajit.State // The Lua state
|
||||||
sandbox *Sandbox // Associated sandbox
|
sandbox *Sandbox // Associated sandbox
|
||||||
index int // Index for debugging
|
index int // Index for debugging
|
||||||
inUse bool // Whether the state is currently in use
|
inUse atomic.Bool // Whether the state is currently in use
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runner runs Lua scripts using a pool of Lua states
|
// Runner runs Lua scripts using a pool of Lua states
|
||||||
@ -115,14 +115,16 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
|
|||||||
InitSQLite(runner.dataDir)
|
InitSQLite(runner.dataDir)
|
||||||
InitFS(runner.fsDir)
|
InitFS(runner.fsDir)
|
||||||
|
|
||||||
|
SetSQLitePoolSize(runner.poolSize)
|
||||||
|
|
||||||
// Initialize states and pool
|
// Initialize states and pool
|
||||||
runner.states = make([]*State, runner.poolSize)
|
runner.states = make([]*State, runner.poolSize)
|
||||||
runner.statePool = make(chan int, runner.poolSize)
|
runner.statePool = make(chan int, runner.poolSize)
|
||||||
|
|
||||||
// Create and initialize all states
|
// Create and initialize all states
|
||||||
if err := runner.initializeStates(); err != nil {
|
if err := runner.initializeStates(); err != nil {
|
||||||
CleanupSQLite() // Clean up SQLite connections
|
CleanupSQLite()
|
||||||
runner.Close() // Clean up already created states
|
runner.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,7 +192,6 @@ func (r *Runner) createState(index int) (*State, error) {
|
|||||||
L: L,
|
L: L,
|
||||||
sandbox: sb,
|
sandbox: sb,
|
||||||
index: index,
|
index: index,
|
||||||
inUse: false,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,29 +216,26 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context,
|
|||||||
// Got a state
|
// Got a state
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(1 * time.Second):
|
||||||
return nil, ErrTimeout
|
return nil, ErrTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the actual state
|
|
||||||
state := r.states[stateIndex]
|
state := r.states[stateIndex]
|
||||||
if state == nil {
|
if state == nil {
|
||||||
r.statePool <- stateIndex
|
r.statePool <- stateIndex
|
||||||
return nil, ErrStateNotReady
|
return nil, ErrStateNotReady
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark state as in use
|
// Use atomic operations
|
||||||
state.inUse = true
|
state.inUse.Store(true)
|
||||||
|
|
||||||
// Ensure state is returned to pool when done
|
|
||||||
defer func() {
|
defer func() {
|
||||||
state.inUse = false
|
state.inUse.Store(false)
|
||||||
if r.isRunning.Load() {
|
if r.isRunning.Load() {
|
||||||
select {
|
select {
|
||||||
case r.statePool <- stateIndex:
|
case r.statePool <- stateIndex:
|
||||||
// State returned to pool
|
|
||||||
default:
|
default:
|
||||||
// Pool is full or closed
|
// Pool is full or closed, state will be cleaned up by Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -267,21 +265,45 @@ func (r *Runner) Close() error {
|
|||||||
|
|
||||||
r.isRunning.Store(false)
|
r.isRunning.Store(false)
|
||||||
|
|
||||||
// Drain the state pool
|
// Drain all states from the pool
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-r.statePool:
|
case <-r.statePool:
|
||||||
// Drain one state
|
|
||||||
default:
|
default:
|
||||||
// Pool is empty
|
goto waitForInUse
|
||||||
goto cleanup
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanup:
|
waitForInUse:
|
||||||
// Clean up all states
|
// Wait for in-use states to finish (with timeout)
|
||||||
|
timeout := time.Now().Add(10 * time.Second)
|
||||||
|
for {
|
||||||
|
allIdle := true
|
||||||
|
for _, state := range r.states {
|
||||||
|
if state != nil && state.inUse.Load() {
|
||||||
|
allIdle = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if allIdle {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(timeout) {
|
||||||
|
logger.Warning("Timeout waiting for states to finish during shutdown, forcing close")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now safely close all states
|
||||||
for i, state := range r.states {
|
for i, state := range r.states {
|
||||||
if state != nil {
|
if state != nil {
|
||||||
|
if state.inUse.Load() {
|
||||||
|
logger.Warning("Force closing state %d that is still in use", i)
|
||||||
|
}
|
||||||
state.L.Cleanup()
|
state.L.Cleanup()
|
||||||
state.L.Close()
|
state.L.Close()
|
||||||
r.states[i] = nil
|
r.states[i] = nil
|
||||||
@ -310,19 +332,40 @@ func (r *Runner) RefreshStates() error {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-r.statePool:
|
case <-r.statePool:
|
||||||
// Drain one state
|
|
||||||
default:
|
default:
|
||||||
// Pool is empty
|
goto waitForInUse
|
||||||
goto cleanup
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanup:
|
waitForInUse:
|
||||||
// Destroy all existing states
|
// Wait for in-use states to finish (with timeout)
|
||||||
|
timeout := time.Now().Add(10 * time.Second)
|
||||||
|
for {
|
||||||
|
allIdle := true
|
||||||
|
for _, state := range r.states {
|
||||||
|
if state != nil && state.inUse.Load() {
|
||||||
|
allIdle = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if allIdle {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(timeout) {
|
||||||
|
logger.Warning("Timeout waiting for states to finish, forcing refresh")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now safely destroy all states
|
||||||
for i, state := range r.states {
|
for i, state := range r.states {
|
||||||
if state != nil {
|
if state != nil {
|
||||||
if state.inUse {
|
if state.inUse.Load() {
|
||||||
logger.Warning("Attempting to refresh state %d that is in use", i)
|
logger.Warning("Force closing state %d that is still in use", i)
|
||||||
}
|
}
|
||||||
state.L.Cleanup()
|
state.L.Cleanup()
|
||||||
state.L.Close()
|
state.L.Close()
|
||||||
@ -367,7 +410,7 @@ func (r *Runner) RefreshModule(moduleName string) bool {
|
|||||||
|
|
||||||
success := true
|
success := true
|
||||||
for _, state := range r.states {
|
for _, state := range r.states {
|
||||||
if state == nil || state.inUse {
|
if state == nil || state.inUse.Load() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -403,7 +446,7 @@ func (r *Runner) GetActiveStateCount() int {
|
|||||||
|
|
||||||
count := 0
|
count := 0
|
||||||
for _, state := range r.states {
|
for _, state := range r.states {
|
||||||
if state != nil && state.inUse {
|
if state != nil && state.inUse.Load() {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -459,10 +502,10 @@ func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
|
|||||||
return nil, ErrStateNotReady
|
return nil, ErrStateNotReady
|
||||||
}
|
}
|
||||||
|
|
||||||
state.inUse = true
|
state.inUse.Store(true)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
state.inUse = false
|
state.inUse.Store(false)
|
||||||
if r.isRunning.Load() {
|
if r.isRunning.Load() {
|
||||||
select {
|
select {
|
||||||
case r.statePool <- stateIndex:
|
case r.statePool <- stateIndex:
|
||||||
|
408
runner/sqlite.go
408
runner/sqlite.go
@ -2,8 +2,6 @@ package runner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@ -19,71 +17,29 @@ import (
|
|||||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DbPools maintains database connection pools
|
|
||||||
var (
|
var (
|
||||||
dbPools = make(map[string]*sqlitex.Pool)
|
dbPools = make(map[string]*sqlitex.Pool)
|
||||||
poolsMu sync.RWMutex
|
poolsMu sync.RWMutex
|
||||||
dataDir string
|
dataDir string
|
||||||
|
poolSize = 8 // Default, will be set to match runner pool size
|
||||||
// Connection tracking
|
connTimeout = 5 * time.Second
|
||||||
activeConns = make(map[string]*TrackedConn)
|
|
||||||
activeConnMu sync.RWMutex
|
|
||||||
connTimeout = 5 * time.Minute
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TrackedConn holds a connection with usage tracking
|
|
||||||
type TrackedConn struct {
|
|
||||||
Conn *sqlite.Conn
|
|
||||||
Pool *sqlitex.Pool
|
|
||||||
DBName string
|
|
||||||
LastUsed time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateConnToken creates a unique token for connection tracking
|
|
||||||
func generateConnToken() string {
|
|
||||||
b := make([]byte, 8)
|
|
||||||
rand.Read(b)
|
|
||||||
return base64.URLEncoding.EncodeToString(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitSQLite initializes the SQLite subsystem
|
// InitSQLite initializes the SQLite subsystem
|
||||||
func InitSQLite(dir string) {
|
func InitSQLite(dir string) {
|
||||||
dataDir = dir
|
dataDir = dir
|
||||||
logger.Info("SQLite is g2g! %s", color.Apply(dir, color.Yellow))
|
logger.Info("SQLite is g2g! %s", color.Apply(dir, color.Yellow))
|
||||||
|
|
||||||
// Start connection cleanup goroutine
|
|
||||||
go cleanupIdleConnections()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupIdleConnections periodically checks for and removes idle connections
|
// SetSQLitePoolSize sets the pool size to match the runner pool size
|
||||||
func cleanupIdleConnections() {
|
func SetSQLitePoolSize(size int) {
|
||||||
ticker := time.NewTicker(30 * time.Second)
|
if size > 0 {
|
||||||
defer ticker.Stop()
|
poolSize = size
|
||||||
|
|
||||||
for range ticker.C {
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
activeConnMu.Lock()
|
|
||||||
for token, conn := range activeConns {
|
|
||||||
if conn.LastUsed.Add(connTimeout).Before(now) {
|
|
||||||
logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName)
|
|
||||||
conn.Pool.Put(conn.Conn)
|
|
||||||
delete(activeConns, token)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
activeConnMu.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupSQLite closes all database connections
|
// CleanupSQLite closes all database connections
|
||||||
func CleanupSQLite() {
|
func CleanupSQLite() {
|
||||||
activeConnMu.Lock()
|
|
||||||
for token, conn := range activeConns {
|
|
||||||
conn.Pool.Put(conn.Conn)
|
|
||||||
delete(activeConns, token)
|
|
||||||
}
|
|
||||||
activeConnMu.Unlock()
|
|
||||||
|
|
||||||
poolsMu.Lock()
|
poolsMu.Lock()
|
||||||
defer poolsMu.Unlock()
|
defer poolsMu.Unlock()
|
||||||
|
|
||||||
@ -123,74 +79,36 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
|||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new pool
|
// Create new pool with proper size
|
||||||
dbPath := filepath.Join(dataDir, dbName+".db")
|
dbPath := filepath.Join(dataDir, dbName+".db")
|
||||||
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
|
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
||||||
|
PoolSize: poolSize,
|
||||||
|
PrepareConn: func(conn *sqlite.Conn) error {
|
||||||
|
// Execute PRAGMA statements individually
|
||||||
|
pragmas := []string{
|
||||||
|
"PRAGMA journal_mode = WAL",
|
||||||
|
"PRAGMA synchronous = NORMAL",
|
||||||
|
"PRAGMA cache_size = 1000",
|
||||||
|
"PRAGMA foreign_keys = ON",
|
||||||
|
"PRAGMA temp_store = MEMORY",
|
||||||
|
}
|
||||||
|
for _, pragma := range pragmas {
|
||||||
|
if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dbPools[dbName] = pool
|
dbPools[dbName] = pool
|
||||||
|
logger.Debug("Created SQLite pool for %s (size: %d)", dbName, poolSize)
|
||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getConnection retrieves or creates a tracked connection
|
|
||||||
func getConnection(token, dbName string) (*TrackedConn, string, error) {
|
|
||||||
// If token is provided, try to get existing connection
|
|
||||||
if token != "" {
|
|
||||||
activeConnMu.RLock()
|
|
||||||
conn, exists := activeConns[token]
|
|
||||||
activeConnMu.RUnlock()
|
|
||||||
|
|
||||||
if exists {
|
|
||||||
conn.LastUsed = time.Now()
|
|
||||||
return conn, token, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Token not provided or connection not found, create new
|
|
||||||
pool, err := getPool(dbName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := pool.Take(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate new token
|
|
||||||
newToken := generateConnToken()
|
|
||||||
|
|
||||||
trackedConn := &TrackedConn{
|
|
||||||
Conn: conn,
|
|
||||||
Pool: pool,
|
|
||||||
DBName: dbName,
|
|
||||||
LastUsed: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
activeConnMu.Lock()
|
|
||||||
activeConns[newToken] = trackedConn
|
|
||||||
activeConnMu.Unlock()
|
|
||||||
|
|
||||||
return trackedConn, newToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// releaseConnection releases a connection back to the pool
|
|
||||||
func releaseConnection(token string) bool {
|
|
||||||
activeConnMu.Lock()
|
|
||||||
defer activeConnMu.Unlock()
|
|
||||||
|
|
||||||
conn, exists := activeConns[token]
|
|
||||||
if !exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.Pool.Put(conn.Conn)
|
|
||||||
delete(activeConns, token)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlQuery executes a SQL query and returns results
|
// sqlQuery executes a SQL query and returns results
|
||||||
func sqlQuery(state *luajit.State) int {
|
func sqlQuery(state *luajit.State) int {
|
||||||
// Get required parameters
|
// Get required parameters
|
||||||
@ -202,20 +120,23 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
dbName := state.ToString(1)
|
dbName := state.ToString(1)
|
||||||
query := state.ToString(2)
|
query := state.ToString(2)
|
||||||
|
|
||||||
// Get connection token (optional)
|
// Get pool
|
||||||
var connToken string
|
pool, err := getPool(dbName)
|
||||||
if state.GetTop() >= 4 && state.IsString(4) {
|
|
||||||
connToken = state.ToString(4)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get connection
|
|
||||||
trackedConn, newToken, err := getConnection(connToken, dbName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := trackedConn.Conn
|
// Get connection with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pool.Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
defer pool.Put(conn)
|
||||||
|
|
||||||
// Create execution options
|
// Create execution options
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
@ -223,65 +144,10 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
|
|
||||||
// Set up parameters if provided
|
// Set up parameters if provided
|
||||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||||
if state.IsTable(3) {
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||||
params, err := state.ToTable(3)
|
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||||
if err != nil {
|
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error()))
|
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for array-style params
|
|
||||||
if arr, ok := params[""]; ok {
|
|
||||||
if arrParams, ok := arr.([]any); ok {
|
|
||||||
execOpts.Args = arrParams
|
|
||||||
} else if floatArr, ok := arr.([]float64); ok {
|
|
||||||
args := make([]any, len(floatArr))
|
|
||||||
for i, v := range floatArr {
|
|
||||||
args[i] = v
|
|
||||||
}
|
|
||||||
execOpts.Args = args
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Named parameters
|
|
||||||
named := make(map[string]any, len(params))
|
|
||||||
for k, v := range params {
|
|
||||||
if len(k) > 0 && k[0] != ':' {
|
|
||||||
named[":"+k] = v
|
|
||||||
} else {
|
|
||||||
named[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
execOpts.Named = named
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Positional parameters
|
|
||||||
count := state.GetTop() - 2
|
|
||||||
if state.IsString(4) {
|
|
||||||
count-- // Don't include connection token
|
|
||||||
}
|
|
||||||
args := make([]any, count)
|
|
||||||
for i := range count {
|
|
||||||
idx := i + 3
|
|
||||||
switch state.GetType(idx) {
|
|
||||||
case luajit.TypeNumber:
|
|
||||||
args[i] = state.ToNumber(idx)
|
|
||||||
case luajit.TypeString:
|
|
||||||
args[i] = state.ToString(idx)
|
|
||||||
case luajit.TypeBoolean:
|
|
||||||
args[i] = state.ToBoolean(idx)
|
|
||||||
case luajit.TypeNil:
|
|
||||||
args[i] = nil
|
|
||||||
default:
|
|
||||||
val, err := state.ToValue(idx)
|
|
||||||
if err != nil {
|
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error()))
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
args[i] = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
execOpts.Args = args
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up result function
|
// Set up result function
|
||||||
@ -300,8 +166,12 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
row[colName] = stmt.ColumnText(i)
|
row[colName] = stmt.ColumnText(i)
|
||||||
case sqlite.TypeBlob:
|
case sqlite.TypeBlob:
|
||||||
blobSize := stmt.ColumnLen(i)
|
blobSize := stmt.ColumnLen(i)
|
||||||
|
if blobSize > 0 {
|
||||||
buf := make([]byte, blobSize)
|
buf := make([]byte, blobSize)
|
||||||
row[colName] = stmt.ColumnBytes(i, buf)
|
row[colName] = stmt.ColumnBytes(i, buf)
|
||||||
|
} else {
|
||||||
|
row[colName] = []byte{}
|
||||||
|
}
|
||||||
case sqlite.TypeNull:
|
case sqlite.TypeNull:
|
||||||
row[colName] = nil
|
row[colName] = nil
|
||||||
}
|
}
|
||||||
@ -327,10 +197,7 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
state.SetTable(-3)
|
state.SetTable(-3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return connection token
|
return 1
|
||||||
state.PushString(newToken)
|
|
||||||
|
|
||||||
return 2
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlExec executes a SQL statement without returning results
|
// sqlExec executes a SQL statement without returning results
|
||||||
@ -344,58 +211,73 @@ func sqlExec(state *luajit.State) int {
|
|||||||
dbName := state.ToString(1)
|
dbName := state.ToString(1)
|
||||||
query := state.ToString(2)
|
query := state.ToString(2)
|
||||||
|
|
||||||
// Get connection token (optional)
|
// Get pool
|
||||||
var connToken string
|
pool, err := getPool(dbName)
|
||||||
if state.GetTop() >= 4 && state.IsString(4) {
|
|
||||||
connToken = state.ToString(4)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get connection
|
|
||||||
trackedConn, newToken, err := getConnection(connToken, dbName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := trackedConn.Conn
|
// Get connection with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pool.Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
defer pool.Put(conn)
|
||||||
|
|
||||||
// Check if parameters are provided
|
// Check if parameters are provided
|
||||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||||
|
|
||||||
// Fast path for multi-statement scripts - use ExecScript
|
// Fast path for multi-statement scripts
|
||||||
if strings.Contains(query, ";") && !hasParams {
|
if strings.Contains(query, ";") && !hasParams {
|
||||||
if err := sqlitex.ExecScript(conn, query); err != nil {
|
if err := sqlitex.ExecScript(conn, query); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
state.PushNumber(float64(conn.Changes()))
|
state.PushNumber(float64(conn.Changes()))
|
||||||
state.PushString(newToken)
|
return 1
|
||||||
return 2
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fast path for simple queries with no parameters
|
// Fast path for simple queries with no parameters
|
||||||
if !hasParams {
|
if !hasParams {
|
||||||
// Use Execute for simple statements without parameters
|
|
||||||
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
state.PushNumber(float64(conn.Changes()))
|
state.PushNumber(float64(conn.Changes()))
|
||||||
state.PushString(newToken)
|
return 1
|
||||||
return 2
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create execution options for parameterized query
|
// Create execution options for parameterized query
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||||
// Set up parameters
|
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||||
if state.IsTable(3) {
|
|
||||||
params, err := state.ToTable(3)
|
|
||||||
if err != nil {
|
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error()))
|
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute with parameters
|
||||||
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return affected rows
|
||||||
|
state.PushNumber(float64(conn.Changes()))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupParams configures execution options with parameters from Lua
|
||||||
|
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
|
||||||
|
if state.IsTable(paramIndex) {
|
||||||
|
params, err := state.ToTable(paramIndex)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid parameters: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Check for array-style params
|
// Check for array-style params
|
||||||
if arr, ok := params[""]; ok {
|
if arr, ok := params[""]; ok {
|
||||||
if arrParams, ok := arr.([]any); ok {
|
if arrParams, ok := arr.([]any); ok {
|
||||||
@ -420,59 +302,111 @@ func sqlExec(state *luajit.State) int {
|
|||||||
execOpts.Named = named
|
execOpts.Named = named
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Positional parameters
|
// Positional parameters from stack
|
||||||
count := state.GetTop() - 2
|
count := state.GetTop() - 2
|
||||||
if state.IsString(4) {
|
|
||||||
count-- // Don't include connection token
|
|
||||||
}
|
|
||||||
args := make([]any, count)
|
args := make([]any, count)
|
||||||
for i := range count {
|
for i := range count {
|
||||||
idx := i + 3
|
idx := i + 3
|
||||||
switch state.GetType(idx) {
|
|
||||||
case luajit.TypeNumber:
|
|
||||||
args[i] = state.ToNumber(idx)
|
|
||||||
case luajit.TypeString:
|
|
||||||
args[i] = state.ToString(idx)
|
|
||||||
case luajit.TypeBoolean:
|
|
||||||
args[i] = state.ToBoolean(idx)
|
|
||||||
case luajit.TypeNil:
|
|
||||||
args[i] = nil
|
|
||||||
default:
|
|
||||||
val, err := state.ToValue(idx)
|
val, err := state.ToValue(idx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error()))
|
return fmt.Errorf("invalid parameter %d: %w", i+1, err)
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
args[i] = val
|
args[i] = val
|
||||||
}
|
}
|
||||||
}
|
|
||||||
execOpts.Args = args
|
execOpts.Args = args
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute with parameters
|
return nil
|
||||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return affected rows and connection token
|
|
||||||
state.PushNumber(float64(conn.Changes()))
|
|
||||||
state.PushString(newToken)
|
|
||||||
return 2
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlClose releases a connection back to the pool
|
// sqlGetOne executes a query and returns only the first row
|
||||||
func sqlClose(state *luajit.State) int {
|
func sqlGetOne(state *luajit.State) int {
|
||||||
if state.GetTop() < 1 || !state.IsString(1) {
|
// Get required parameters
|
||||||
state.PushString("sqlite.close: requires connection token")
|
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
||||||
|
state.PushString("sqlite.get_one: requires database name and query")
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
token := state.ToString(1)
|
dbName := state.ToString(1)
|
||||||
if releaseConnection(token) {
|
query := state.ToString(2)
|
||||||
state.PushBoolean(true)
|
|
||||||
|
// Get pool
|
||||||
|
pool, err := getPool(dbName)
|
||||||
|
if err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get connection with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pool.Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
defer pool.Put(conn)
|
||||||
|
|
||||||
|
// Create execution options
|
||||||
|
var execOpts sqlitex.ExecOptions
|
||||||
|
var result map[string]any
|
||||||
|
|
||||||
|
// Set up parameters if provided
|
||||||
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||||
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up result function to get only first row
|
||||||
|
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||||
|
if result != nil {
|
||||||
|
return nil // Already got first row
|
||||||
|
}
|
||||||
|
|
||||||
|
result = make(map[string]any)
|
||||||
|
colCount := stmt.ColumnCount()
|
||||||
|
|
||||||
|
for i := range colCount {
|
||||||
|
colName := stmt.ColumnName(i)
|
||||||
|
switch stmt.ColumnType(i) {
|
||||||
|
case sqlite.TypeInteger:
|
||||||
|
result[colName] = stmt.ColumnInt64(i)
|
||||||
|
case sqlite.TypeFloat:
|
||||||
|
result[colName] = stmt.ColumnFloat(i)
|
||||||
|
case sqlite.TypeText:
|
||||||
|
result[colName] = stmt.ColumnText(i)
|
||||||
|
case sqlite.TypeBlob:
|
||||||
|
blobSize := stmt.ColumnLen(i)
|
||||||
|
if blobSize > 0 {
|
||||||
|
buf := make([]byte, blobSize)
|
||||||
|
result[colName] = stmt.ColumnBytes(i, buf)
|
||||||
} else {
|
} else {
|
||||||
state.PushBoolean(false)
|
result[colName] = []byte{}
|
||||||
|
}
|
||||||
|
case sqlite.TypeNull:
|
||||||
|
result[colName] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute query
|
||||||
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return result or nil if no rows
|
||||||
|
if result == nil {
|
||||||
|
state.PushNil()
|
||||||
|
} else {
|
||||||
|
if err := state.PushTable(result); err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
@ -486,7 +420,7 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
|
|||||||
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
|
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil {
|
if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
Loading…
x
Reference in New Issue
Block a user