major sqlite bug fix, minor lua state closing fix, add headers to lua ctx
This commit is contained in:
parent
5b698f31e4
commit
266da9fd23
@ -210,10 +210,17 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
|
||||
luaCtx.Set("host", string(ctx.Host()))
|
||||
luaCtx.Set("session", sessionMap)
|
||||
|
||||
// Add headers to context
|
||||
headers := make(map[string]any)
|
||||
ctx.Request.Header.VisitAll(func(key, value []byte) {
|
||||
headers[string(key)] = string(value)
|
||||
})
|
||||
luaCtx.Set("headers", headers)
|
||||
|
||||
// Handle params
|
||||
if params != nil && params.Count > 0 {
|
||||
paramMap := s.paramsPool.Get().(map[string]any)
|
||||
for i := 0; i < params.Count; i++ {
|
||||
for i := range params.Count {
|
||||
paramMap[params.Keys[i]] = params.Values[i]
|
||||
}
|
||||
luaCtx.Set("params", paramMap)
|
||||
|
@ -16,9 +16,7 @@ local connection_mt = {
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
local results, token = __sqlite_query(self.db_name, query, normalized_params, self.conn_token)
|
||||
self.conn_token = token
|
||||
return results
|
||||
return __sqlite_query(self.db_name, query, normalized_params)
|
||||
end,
|
||||
|
||||
exec = function(self, query, params, ...)
|
||||
@ -27,18 +25,16 @@ local connection_mt = {
|
||||
end
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
local affected, token = __sqlite_exec(self.db_name, query, normalized_params, self.conn_token)
|
||||
self.conn_token = token
|
||||
return affected
|
||||
return __sqlite_exec(self.db_name, query, normalized_params)
|
||||
end,
|
||||
|
||||
close = function(self)
|
||||
if self.conn_token then
|
||||
local success = __sqlite_close(self.conn_token)
|
||||
self.conn_token = nil
|
||||
return success
|
||||
get_one = function(self, query, params, ...)
|
||||
if type(query) ~= "string" then
|
||||
error("connection:get_one: query must be a string", 2)
|
||||
end
|
||||
return false
|
||||
|
||||
local normalized_params = normalize_params(params, ...)
|
||||
return __sqlite_get_one(self.db_name, query, normalized_params)
|
||||
end,
|
||||
|
||||
insert = function(self, table_name, data, columns)
|
||||
@ -249,20 +245,6 @@ local connection_mt = {
|
||||
return self:exec(query, normalize_params(params, ...))
|
||||
end,
|
||||
|
||||
get_one = function(self, query, params, ...)
|
||||
if type(query) ~= "string" then
|
||||
error("connection:get_one: query must be a string", 2)
|
||||
end
|
||||
|
||||
local limited_query = query
|
||||
if not string.contains(query:lower(), "limit") then
|
||||
limited_query = query .. " LIMIT 1"
|
||||
end
|
||||
|
||||
local results = self:query(limited_query, normalize_params(params, ...))
|
||||
return results[1]
|
||||
end,
|
||||
|
||||
exists = function(self, table_name, where, params, ...)
|
||||
if type(table_name) ~= "string" then
|
||||
error("connection:exists: table_name must be a string", 2)
|
||||
@ -310,7 +292,6 @@ return function(db_name)
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
db_name = db_name,
|
||||
conn_token = nil
|
||||
db_name = db_name
|
||||
}, connection_mt)
|
||||
end
|
||||
|
101
runner/runner.go
101
runner/runner.go
@ -33,7 +33,7 @@ type State struct {
|
||||
L *luajit.State // The Lua state
|
||||
sandbox *Sandbox // Associated sandbox
|
||||
index int // Index for debugging
|
||||
inUse bool // Whether the state is currently in use
|
||||
inUse atomic.Bool // Whether the state is currently in use
|
||||
}
|
||||
|
||||
// Runner runs Lua scripts using a pool of Lua states
|
||||
@ -115,14 +115,16 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
|
||||
InitSQLite(runner.dataDir)
|
||||
InitFS(runner.fsDir)
|
||||
|
||||
SetSQLitePoolSize(runner.poolSize)
|
||||
|
||||
// Initialize states and pool
|
||||
runner.states = make([]*State, runner.poolSize)
|
||||
runner.statePool = make(chan int, runner.poolSize)
|
||||
|
||||
// Create and initialize all states
|
||||
if err := runner.initializeStates(); err != nil {
|
||||
CleanupSQLite() // Clean up SQLite connections
|
||||
runner.Close() // Clean up already created states
|
||||
CleanupSQLite()
|
||||
runner.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -190,7 +192,6 @@ func (r *Runner) createState(index int) (*State, error) {
|
||||
L: L,
|
||||
sandbox: sb,
|
||||
index: index,
|
||||
inUse: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -215,29 +216,26 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context,
|
||||
// Got a state
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-time.After(1 * time.Second):
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
|
||||
// Get the actual state
|
||||
state := r.states[stateIndex]
|
||||
if state == nil {
|
||||
r.statePool <- stateIndex
|
||||
return nil, ErrStateNotReady
|
||||
}
|
||||
|
||||
// Mark state as in use
|
||||
state.inUse = true
|
||||
// Use atomic operations
|
||||
state.inUse.Store(true)
|
||||
|
||||
// Ensure state is returned to pool when done
|
||||
defer func() {
|
||||
state.inUse = false
|
||||
state.inUse.Store(false)
|
||||
if r.isRunning.Load() {
|
||||
select {
|
||||
case r.statePool <- stateIndex:
|
||||
// State returned to pool
|
||||
default:
|
||||
// Pool is full or closed
|
||||
// Pool is full or closed, state will be cleaned up by Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -267,21 +265,45 @@ func (r *Runner) Close() error {
|
||||
|
||||
r.isRunning.Store(false)
|
||||
|
||||
// Drain the state pool
|
||||
// Drain all states from the pool
|
||||
for {
|
||||
select {
|
||||
case <-r.statePool:
|
||||
// Drain one state
|
||||
default:
|
||||
// Pool is empty
|
||||
goto cleanup
|
||||
goto waitForInUse
|
||||
}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
// Clean up all states
|
||||
waitForInUse:
|
||||
// Wait for in-use states to finish (with timeout)
|
||||
timeout := time.Now().Add(10 * time.Second)
|
||||
for {
|
||||
allIdle := true
|
||||
for _, state := range r.states {
|
||||
if state != nil && state.inUse.Load() {
|
||||
allIdle = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allIdle {
|
||||
break
|
||||
}
|
||||
|
||||
if time.Now().After(timeout) {
|
||||
logger.Warning("Timeout waiting for states to finish during shutdown, forcing close")
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Now safely close all states
|
||||
for i, state := range r.states {
|
||||
if state != nil {
|
||||
if state.inUse.Load() {
|
||||
logger.Warning("Force closing state %d that is still in use", i)
|
||||
}
|
||||
state.L.Cleanup()
|
||||
state.L.Close()
|
||||
r.states[i] = nil
|
||||
@ -310,19 +332,40 @@ func (r *Runner) RefreshStates() error {
|
||||
for {
|
||||
select {
|
||||
case <-r.statePool:
|
||||
// Drain one state
|
||||
default:
|
||||
// Pool is empty
|
||||
goto cleanup
|
||||
goto waitForInUse
|
||||
}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
// Destroy all existing states
|
||||
waitForInUse:
|
||||
// Wait for in-use states to finish (with timeout)
|
||||
timeout := time.Now().Add(10 * time.Second)
|
||||
for {
|
||||
allIdle := true
|
||||
for _, state := range r.states {
|
||||
if state != nil && state.inUse.Load() {
|
||||
allIdle = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allIdle {
|
||||
break
|
||||
}
|
||||
|
||||
if time.Now().After(timeout) {
|
||||
logger.Warning("Timeout waiting for states to finish, forcing refresh")
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Now safely destroy all states
|
||||
for i, state := range r.states {
|
||||
if state != nil {
|
||||
if state.inUse {
|
||||
logger.Warning("Attempting to refresh state %d that is in use", i)
|
||||
if state.inUse.Load() {
|
||||
logger.Warning("Force closing state %d that is still in use", i)
|
||||
}
|
||||
state.L.Cleanup()
|
||||
state.L.Close()
|
||||
@ -367,7 +410,7 @@ func (r *Runner) RefreshModule(moduleName string) bool {
|
||||
|
||||
success := true
|
||||
for _, state := range r.states {
|
||||
if state == nil || state.inUse {
|
||||
if state == nil || state.inUse.Load() {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -403,7 +446,7 @@ func (r *Runner) GetActiveStateCount() int {
|
||||
|
||||
count := 0
|
||||
for _, state := range r.states {
|
||||
if state != nil && state.inUse {
|
||||
if state != nil && state.inUse.Load() {
|
||||
count++
|
||||
}
|
||||
}
|
||||
@ -459,10 +502,10 @@ func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
|
||||
return nil, ErrStateNotReady
|
||||
}
|
||||
|
||||
state.inUse = true
|
||||
state.inUse.Store(true)
|
||||
|
||||
defer func() {
|
||||
state.inUse = false
|
||||
state.inUse.Store(false)
|
||||
if r.isRunning.Load() {
|
||||
select {
|
||||
case r.statePool <- stateIndex:
|
||||
|
424
runner/sqlite.go
424
runner/sqlite.go
@ -2,8 +2,6 @@ package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -19,71 +17,29 @@ import (
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// DbPools maintains database connection pools
|
||||
var (
|
||||
dbPools = make(map[string]*sqlitex.Pool)
|
||||
poolsMu sync.RWMutex
|
||||
dataDir string
|
||||
|
||||
// Connection tracking
|
||||
activeConns = make(map[string]*TrackedConn)
|
||||
activeConnMu sync.RWMutex
|
||||
connTimeout = 5 * time.Minute
|
||||
dbPools = make(map[string]*sqlitex.Pool)
|
||||
poolsMu sync.RWMutex
|
||||
dataDir string
|
||||
poolSize = 8 // Default, will be set to match runner pool size
|
||||
connTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// TrackedConn holds a connection with usage tracking
|
||||
type TrackedConn struct {
|
||||
Conn *sqlite.Conn
|
||||
Pool *sqlitex.Pool
|
||||
DBName string
|
||||
LastUsed time.Time
|
||||
}
|
||||
|
||||
// generateConnToken creates a unique token for connection tracking
|
||||
func generateConnToken() string {
|
||||
b := make([]byte, 8)
|
||||
rand.Read(b)
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// InitSQLite initializes the SQLite subsystem
|
||||
func InitSQLite(dir string) {
|
||||
dataDir = dir
|
||||
logger.Info("SQLite is g2g! %s", color.Apply(dir, color.Yellow))
|
||||
|
||||
// Start connection cleanup goroutine
|
||||
go cleanupIdleConnections()
|
||||
}
|
||||
|
||||
// cleanupIdleConnections periodically checks for and removes idle connections
|
||||
func cleanupIdleConnections() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
|
||||
activeConnMu.Lock()
|
||||
for token, conn := range activeConns {
|
||||
if conn.LastUsed.Add(connTimeout).Before(now) {
|
||||
logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName)
|
||||
conn.Pool.Put(conn.Conn)
|
||||
delete(activeConns, token)
|
||||
}
|
||||
}
|
||||
activeConnMu.Unlock()
|
||||
// SetSQLitePoolSize sets the pool size to match the runner pool size
|
||||
func SetSQLitePoolSize(size int) {
|
||||
if size > 0 {
|
||||
poolSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupSQLite closes all database connections
|
||||
func CleanupSQLite() {
|
||||
activeConnMu.Lock()
|
||||
for token, conn := range activeConns {
|
||||
conn.Pool.Put(conn.Conn)
|
||||
delete(activeConns, token)
|
||||
}
|
||||
activeConnMu.Unlock()
|
||||
|
||||
poolsMu.Lock()
|
||||
defer poolsMu.Unlock()
|
||||
|
||||
@ -123,74 +79,36 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Create new pool
|
||||
// Create new pool with proper size
|
||||
dbPath := filepath.Join(dataDir, dbName+".db")
|
||||
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
|
||||
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
||||
PoolSize: poolSize,
|
||||
PrepareConn: func(conn *sqlite.Conn) error {
|
||||
// Execute PRAGMA statements individually
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode = WAL",
|
||||
"PRAGMA synchronous = NORMAL",
|
||||
"PRAGMA cache_size = 1000",
|
||||
"PRAGMA foreign_keys = ON",
|
||||
"PRAGMA temp_store = MEMORY",
|
||||
}
|
||||
for _, pragma := range pragmas {
|
||||
if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
dbPools[dbName] = pool
|
||||
logger.Debug("Created SQLite pool for %s (size: %d)", dbName, poolSize)
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// getConnection retrieves or creates a tracked connection
|
||||
func getConnection(token, dbName string) (*TrackedConn, string, error) {
|
||||
// If token is provided, try to get existing connection
|
||||
if token != "" {
|
||||
activeConnMu.RLock()
|
||||
conn, exists := activeConns[token]
|
||||
activeConnMu.RUnlock()
|
||||
|
||||
if exists {
|
||||
conn.LastUsed = time.Now()
|
||||
return conn, token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Token not provided or connection not found, create new
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
conn, err := pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Generate new token
|
||||
newToken := generateConnToken()
|
||||
|
||||
trackedConn := &TrackedConn{
|
||||
Conn: conn,
|
||||
Pool: pool,
|
||||
DBName: dbName,
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
|
||||
activeConnMu.Lock()
|
||||
activeConns[newToken] = trackedConn
|
||||
activeConnMu.Unlock()
|
||||
|
||||
return trackedConn, newToken, nil
|
||||
}
|
||||
|
||||
// releaseConnection releases a connection back to the pool
|
||||
func releaseConnection(token string) bool {
|
||||
activeConnMu.Lock()
|
||||
defer activeConnMu.Unlock()
|
||||
|
||||
conn, exists := activeConns[token]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
conn.Pool.Put(conn.Conn)
|
||||
delete(activeConns, token)
|
||||
return true
|
||||
}
|
||||
|
||||
// sqlQuery executes a SQL query and returns results
|
||||
func sqlQuery(state *luajit.State) int {
|
||||
// Get required parameters
|
||||
@ -202,20 +120,23 @@ func sqlQuery(state *luajit.State) int {
|
||||
dbName := state.ToString(1)
|
||||
query := state.ToString(2)
|
||||
|
||||
// Get connection token (optional)
|
||||
var connToken string
|
||||
if state.GetTop() >= 4 && state.IsString(4) {
|
||||
connToken = state.ToString(4)
|
||||
}
|
||||
|
||||
// Get connection
|
||||
trackedConn, newToken, err := getConnection(connToken, dbName)
|
||||
// Get pool
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
conn := trackedConn.Conn
|
||||
// Get connection with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
// Create execution options
|
||||
var execOpts sqlitex.ExecOptions
|
||||
@ -223,64 +144,9 @@ func sqlQuery(state *luajit.State) int {
|
||||
|
||||
// Set up parameters if provided
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
if state.IsTable(3) {
|
||||
params, err := state.ToTable(3)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Check for array-style params
|
||||
if arr, ok := params[""]; ok {
|
||||
if arrParams, ok := arr.([]any); ok {
|
||||
execOpts.Args = arrParams
|
||||
} else if floatArr, ok := arr.([]float64); ok {
|
||||
args := make([]any, len(floatArr))
|
||||
for i, v := range floatArr {
|
||||
args[i] = v
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
} else {
|
||||
// Named parameters
|
||||
named := make(map[string]any, len(params))
|
||||
for k, v := range params {
|
||||
if len(k) > 0 && k[0] != ':' {
|
||||
named[":"+k] = v
|
||||
} else {
|
||||
named[k] = v
|
||||
}
|
||||
}
|
||||
execOpts.Named = named
|
||||
}
|
||||
} else {
|
||||
// Positional parameters
|
||||
count := state.GetTop() - 2
|
||||
if state.IsString(4) {
|
||||
count-- // Don't include connection token
|
||||
}
|
||||
args := make([]any, count)
|
||||
for i := range count {
|
||||
idx := i + 3
|
||||
switch state.GetType(idx) {
|
||||
case luajit.TypeNumber:
|
||||
args[i] = state.ToNumber(idx)
|
||||
case luajit.TypeString:
|
||||
args[i] = state.ToString(idx)
|
||||
case luajit.TypeBoolean:
|
||||
args[i] = state.ToBoolean(idx)
|
||||
case luajit.TypeNil:
|
||||
args[i] = nil
|
||||
default:
|
||||
val, err := state.ToValue(idx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error()))
|
||||
return -1
|
||||
}
|
||||
args[i] = val
|
||||
}
|
||||
}
|
||||
execOpts.Args = args
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
@ -300,8 +166,12 @@ func sqlQuery(state *luajit.State) int {
|
||||
row[colName] = stmt.ColumnText(i)
|
||||
case sqlite.TypeBlob:
|
||||
blobSize := stmt.ColumnLen(i)
|
||||
buf := make([]byte, blobSize)
|
||||
row[colName] = stmt.ColumnBytes(i, buf)
|
||||
if blobSize > 0 {
|
||||
buf := make([]byte, blobSize)
|
||||
row[colName] = stmt.ColumnBytes(i, buf)
|
||||
} else {
|
||||
row[colName] = []byte{}
|
||||
}
|
||||
case sqlite.TypeNull:
|
||||
row[colName] = nil
|
||||
}
|
||||
@ -327,10 +197,7 @@ func sqlQuery(state *luajit.State) int {
|
||||
state.SetTable(-3)
|
||||
}
|
||||
|
||||
// Return connection token
|
||||
state.PushString(newToken)
|
||||
|
||||
return 2
|
||||
return 1
|
||||
}
|
||||
|
||||
// sqlExec executes a SQL statement without returning results
|
||||
@ -344,56 +211,71 @@ func sqlExec(state *luajit.State) int {
|
||||
dbName := state.ToString(1)
|
||||
query := state.ToString(2)
|
||||
|
||||
// Get connection token (optional)
|
||||
var connToken string
|
||||
if state.GetTop() >= 4 && state.IsString(4) {
|
||||
connToken = state.ToString(4)
|
||||
}
|
||||
|
||||
// Get connection
|
||||
trackedConn, newToken, err := getConnection(connToken, dbName)
|
||||
// Get pool
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
conn := trackedConn.Conn
|
||||
// Get connection with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
// Check if parameters are provided
|
||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||
|
||||
// Fast path for multi-statement scripts - use ExecScript
|
||||
// Fast path for multi-statement scripts
|
||||
if strings.Contains(query, ";") && !hasParams {
|
||||
if err := sqlitex.ExecScript(conn, query); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
state.PushString(newToken)
|
||||
return 2
|
||||
return 1
|
||||
}
|
||||
|
||||
// Fast path for simple queries with no parameters
|
||||
if !hasParams {
|
||||
// Use Execute for simple statements without parameters
|
||||
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
state.PushString(newToken)
|
||||
return 2
|
||||
return 1
|
||||
}
|
||||
|
||||
// Create execution options for parameterized query
|
||||
var execOpts sqlitex.ExecOptions
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Set up parameters
|
||||
if state.IsTable(3) {
|
||||
params, err := state.ToTable(3)
|
||||
// Execute with parameters
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Return affected rows
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
}
|
||||
|
||||
// setupParams configures execution options with parameters from Lua
|
||||
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
|
||||
if state.IsTable(paramIndex) {
|
||||
params, err := state.ToTable(paramIndex)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error()))
|
||||
return -1
|
||||
return fmt.Errorf("invalid parameters: %w", err)
|
||||
}
|
||||
|
||||
// Check for array-style params
|
||||
@ -420,59 +302,111 @@ func sqlExec(state *luajit.State) int {
|
||||
execOpts.Named = named
|
||||
}
|
||||
} else {
|
||||
// Positional parameters
|
||||
// Positional parameters from stack
|
||||
count := state.GetTop() - 2
|
||||
if state.IsString(4) {
|
||||
count-- // Don't include connection token
|
||||
}
|
||||
args := make([]any, count)
|
||||
for i := range count {
|
||||
idx := i + 3
|
||||
switch state.GetType(idx) {
|
||||
case luajit.TypeNumber:
|
||||
args[i] = state.ToNumber(idx)
|
||||
case luajit.TypeString:
|
||||
args[i] = state.ToString(idx)
|
||||
case luajit.TypeBoolean:
|
||||
args[i] = state.ToBoolean(idx)
|
||||
case luajit.TypeNil:
|
||||
args[i] = nil
|
||||
default:
|
||||
val, err := state.ToValue(idx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error()))
|
||||
return -1
|
||||
}
|
||||
args[i] = val
|
||||
val, err := state.ToValue(idx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid parameter %d: %w", i+1, err)
|
||||
}
|
||||
args[i] = val
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
|
||||
// Execute with parameters
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Return affected rows and connection token
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
state.PushString(newToken)
|
||||
return 2
|
||||
return nil
|
||||
}
|
||||
|
||||
// sqlClose releases a connection back to the pool
|
||||
func sqlClose(state *luajit.State) int {
|
||||
if state.GetTop() < 1 || !state.IsString(1) {
|
||||
state.PushString("sqlite.close: requires connection token")
|
||||
// sqlGetOne executes a query and returns only the first row
|
||||
func sqlGetOne(state *luajit.State) int {
|
||||
// Get required parameters
|
||||
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
||||
state.PushString("sqlite.get_one: requires database name and query")
|
||||
return -1
|
||||
}
|
||||
|
||||
token := state.ToString(1)
|
||||
if releaseConnection(token) {
|
||||
state.PushBoolean(true)
|
||||
dbName := state.ToString(1)
|
||||
query := state.ToString(2)
|
||||
|
||||
// Get pool
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Get connection with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Take(ctx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
// Create execution options
|
||||
var execOpts sqlitex.ExecOptions
|
||||
var result map[string]any
|
||||
|
||||
// Set up parameters if provided
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
// Set up result function to get only first row
|
||||
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
if result != nil {
|
||||
return nil // Already got first row
|
||||
}
|
||||
|
||||
result = make(map[string]any)
|
||||
colCount := stmt.ColumnCount()
|
||||
|
||||
for i := range colCount {
|
||||
colName := stmt.ColumnName(i)
|
||||
switch stmt.ColumnType(i) {
|
||||
case sqlite.TypeInteger:
|
||||
result[colName] = stmt.ColumnInt64(i)
|
||||
case sqlite.TypeFloat:
|
||||
result[colName] = stmt.ColumnFloat(i)
|
||||
case sqlite.TypeText:
|
||||
result[colName] = stmt.ColumnText(i)
|
||||
case sqlite.TypeBlob:
|
||||
blobSize := stmt.ColumnLen(i)
|
||||
if blobSize > 0 {
|
||||
buf := make([]byte, blobSize)
|
||||
result[colName] = stmt.ColumnBytes(i, buf)
|
||||
} else {
|
||||
result[colName] = []byte{}
|
||||
}
|
||||
case sqlite.TypeNull:
|
||||
result[colName] = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute query
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Return result or nil if no rows
|
||||
if result == nil {
|
||||
state.PushNil()
|
||||
} else {
|
||||
state.PushBoolean(false)
|
||||
if err := state.PushTable(result); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
return 1
|
||||
@ -486,7 +420,7 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil {
|
||||
if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
Loading…
x
Reference in New Issue
Block a user