major sqlite bug fix, minor lua state closing fix, add headers to lua ctx

This commit is contained in:
Sky Johnson 2025-05-30 13:24:58 -05:00
parent 5b698f31e4
commit 266da9fd23
4 changed files with 268 additions and 303 deletions

View File

@ -210,10 +210,17 @@ func (s *Server) handleLuaRoute(ctx *fasthttp.RequestCtx, bytecode []byte, scrip
luaCtx.Set("host", string(ctx.Host()))
luaCtx.Set("session", sessionMap)
// Add headers to context
headers := make(map[string]any)
ctx.Request.Header.VisitAll(func(key, value []byte) {
headers[string(key)] = string(value)
})
luaCtx.Set("headers", headers)
// Handle params
if params != nil && params.Count > 0 {
paramMap := s.paramsPool.Get().(map[string]any)
for i := 0; i < params.Count; i++ {
for i := range params.Count {
paramMap[params.Keys[i]] = params.Values[i]
}
luaCtx.Set("params", paramMap)

View File

@ -16,9 +16,7 @@ local connection_mt = {
end
local normalized_params = normalize_params(params, ...)
local results, token = __sqlite_query(self.db_name, query, normalized_params, self.conn_token)
self.conn_token = token
return results
return __sqlite_query(self.db_name, query, normalized_params)
end,
exec = function(self, query, params, ...)
@ -27,18 +25,16 @@ local connection_mt = {
end
local normalized_params = normalize_params(params, ...)
local affected, token = __sqlite_exec(self.db_name, query, normalized_params, self.conn_token)
self.conn_token = token
return affected
return __sqlite_exec(self.db_name, query, normalized_params)
end,
close = function(self)
if self.conn_token then
local success = __sqlite_close(self.conn_token)
self.conn_token = nil
return success
get_one = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:get_one: query must be a string", 2)
end
return false
local normalized_params = normalize_params(params, ...)
return __sqlite_get_one(self.db_name, query, normalized_params)
end,
insert = function(self, table_name, data, columns)
@ -249,20 +245,6 @@ local connection_mt = {
return self:exec(query, normalize_params(params, ...))
end,
get_one = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:get_one: query must be a string", 2)
end
local limited_query = query
if not string.contains(query:lower(), "limit") then
limited_query = query .. " LIMIT 1"
end
local results = self:query(limited_query, normalize_params(params, ...))
return results[1]
end,
exists = function(self, table_name, where, params, ...)
if type(table_name) ~= "string" then
error("connection:exists: table_name must be a string", 2)
@ -310,7 +292,6 @@ return function(db_name)
end
return setmetatable({
db_name = db_name,
conn_token = nil
db_name = db_name
}, connection_mt)
end

View File

@ -33,7 +33,7 @@ type State struct {
L *luajit.State // The Lua state
sandbox *Sandbox // Associated sandbox
index int // Index for debugging
inUse bool // Whether the state is currently in use
inUse atomic.Bool // Whether the state is currently in use
}
// Runner runs Lua scripts using a pool of Lua states
@ -115,14 +115,16 @@ func NewRunner(options ...RunnerOption) (*Runner, error) {
InitSQLite(runner.dataDir)
InitFS(runner.fsDir)
SetSQLitePoolSize(runner.poolSize)
// Initialize states and pool
runner.states = make([]*State, runner.poolSize)
runner.statePool = make(chan int, runner.poolSize)
// Create and initialize all states
if err := runner.initializeStates(); err != nil {
CleanupSQLite() // Clean up SQLite connections
runner.Close() // Clean up already created states
CleanupSQLite()
runner.Close()
return nil, err
}
@ -190,7 +192,6 @@ func (r *Runner) createState(index int) (*State, error) {
L: L,
sandbox: sb,
index: index,
inUse: false,
}, nil
}
@ -215,29 +216,26 @@ func (r *Runner) Execute(ctx context.Context, bytecode []byte, execCtx *Context,
// Got a state
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(5 * time.Second):
case <-time.After(1 * time.Second):
return nil, ErrTimeout
}
// Get the actual state
state := r.states[stateIndex]
if state == nil {
r.statePool <- stateIndex
return nil, ErrStateNotReady
}
// Mark state as in use
state.inUse = true
// Use atomic operations
state.inUse.Store(true)
// Ensure state is returned to pool when done
defer func() {
state.inUse = false
state.inUse.Store(false)
if r.isRunning.Load() {
select {
case r.statePool <- stateIndex:
// State returned to pool
default:
// Pool is full or closed
// Pool is full or closed, state will be cleaned up by Close()
}
}
}()
@ -267,21 +265,45 @@ func (r *Runner) Close() error {
r.isRunning.Store(false)
// Drain the state pool
// Drain all states from the pool
for {
select {
case <-r.statePool:
// Drain one state
default:
// Pool is empty
goto cleanup
goto waitForInUse
}
}
cleanup:
// Clean up all states
waitForInUse:
// Wait for in-use states to finish (with timeout)
timeout := time.Now().Add(10 * time.Second)
for {
allIdle := true
for _, state := range r.states {
if state != nil && state.inUse.Load() {
allIdle = false
break
}
}
if allIdle {
break
}
if time.Now().After(timeout) {
logger.Warning("Timeout waiting for states to finish during shutdown, forcing close")
break
}
time.Sleep(10 * time.Millisecond)
}
// Now safely close all states
for i, state := range r.states {
if state != nil {
if state.inUse.Load() {
logger.Warning("Force closing state %d that is still in use", i)
}
state.L.Cleanup()
state.L.Close()
r.states[i] = nil
@ -310,19 +332,40 @@ func (r *Runner) RefreshStates() error {
for {
select {
case <-r.statePool:
// Drain one state
default:
// Pool is empty
goto cleanup
goto waitForInUse
}
}
cleanup:
// Destroy all existing states
waitForInUse:
// Wait for in-use states to finish (with timeout)
timeout := time.Now().Add(10 * time.Second)
for {
allIdle := true
for _, state := range r.states {
if state != nil && state.inUse.Load() {
allIdle = false
break
}
}
if allIdle {
break
}
if time.Now().After(timeout) {
logger.Warning("Timeout waiting for states to finish, forcing refresh")
break
}
time.Sleep(10 * time.Millisecond)
}
// Now safely destroy all states
for i, state := range r.states {
if state != nil {
if state.inUse {
logger.Warning("Attempting to refresh state %d that is in use", i)
if state.inUse.Load() {
logger.Warning("Force closing state %d that is still in use", i)
}
state.L.Cleanup()
state.L.Close()
@ -367,7 +410,7 @@ func (r *Runner) RefreshModule(moduleName string) bool {
success := true
for _, state := range r.states {
if state == nil || state.inUse {
if state == nil || state.inUse.Load() {
continue
}
@ -403,7 +446,7 @@ func (r *Runner) GetActiveStateCount() int {
count := 0
for _, state := range r.states {
if state != nil && state.inUse {
if state != nil && state.inUse.Load() {
count++
}
}
@ -459,10 +502,10 @@ func (r *Runner) RunScriptFile(filePath string) (*Response, error) {
return nil, ErrStateNotReady
}
state.inUse = true
state.inUse.Store(true)
defer func() {
state.inUse = false
state.inUse.Store(false)
if r.isRunning.Load() {
select {
case r.statePool <- stateIndex:

View File

@ -2,8 +2,6 @@ package runner
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"path/filepath"
"strings"
@ -19,71 +17,29 @@ import (
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// DbPools maintains database connection pools
var (
dbPools = make(map[string]*sqlitex.Pool)
poolsMu sync.RWMutex
dataDir string
// Connection tracking
activeConns = make(map[string]*TrackedConn)
activeConnMu sync.RWMutex
connTimeout = 5 * time.Minute
poolSize = 8 // Default, will be set to match runner pool size
connTimeout = 5 * time.Second
)
// TrackedConn holds a connection with usage tracking
type TrackedConn struct {
Conn *sqlite.Conn
Pool *sqlitex.Pool
DBName string
LastUsed time.Time
}
// generateConnToken creates a unique token for connection tracking
func generateConnToken() string {
b := make([]byte, 8)
rand.Read(b)
return base64.URLEncoding.EncodeToString(b)
}
// InitSQLite initializes the SQLite subsystem
func InitSQLite(dir string) {
dataDir = dir
logger.Info("SQLite is g2g! %s", color.Apply(dir, color.Yellow))
// Start connection cleanup goroutine
go cleanupIdleConnections()
}
// cleanupIdleConnections periodically checks for and removes idle connections
func cleanupIdleConnections() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
activeConnMu.Lock()
for token, conn := range activeConns {
if conn.LastUsed.Add(connTimeout).Before(now) {
logger.Debug("Closing idle connection: %s (%s)", token, conn.DBName)
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
}
}
activeConnMu.Unlock()
// SetSQLitePoolSize sets the pool size to match the runner pool size
func SetSQLitePoolSize(size int) {
if size > 0 {
poolSize = size
}
}
// CleanupSQLite closes all database connections
func CleanupSQLite() {
activeConnMu.Lock()
for token, conn := range activeConns {
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
}
activeConnMu.Unlock()
poolsMu.Lock()
defer poolsMu.Unlock()
@ -123,74 +79,36 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return pool, nil
}
// Create new pool
// Create new pool with proper size
dbPath := filepath.Join(dataDir, dbName+".db")
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
PoolSize: poolSize,
PrepareConn: func(conn *sqlite.Conn) error {
// Execute PRAGMA statements individually
pragmas := []string{
"PRAGMA journal_mode = WAL",
"PRAGMA synchronous = NORMAL",
"PRAGMA cache_size = 1000",
"PRAGMA foreign_keys = ON",
"PRAGMA temp_store = MEMORY",
}
for _, pragma := range pragmas {
if err := sqlitex.ExecuteTransient(conn, pragma, nil); err != nil {
return err
}
}
return nil
},
})
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
dbPools[dbName] = pool
logger.Debug("Created SQLite pool for %s (size: %d)", dbName, poolSize)
return pool, nil
}
// getConnection retrieves or creates a tracked connection
func getConnection(token, dbName string) (*TrackedConn, string, error) {
// If token is provided, try to get existing connection
if token != "" {
activeConnMu.RLock()
conn, exists := activeConns[token]
activeConnMu.RUnlock()
if exists {
conn.LastUsed = time.Now()
return conn, token, nil
}
}
// Token not provided or connection not found, create new
pool, err := getPool(dbName)
if err != nil {
return nil, "", err
}
conn, err := pool.Take(context.Background())
if err != nil {
return nil, "", err
}
// Generate new token
newToken := generateConnToken()
trackedConn := &TrackedConn{
Conn: conn,
Pool: pool,
DBName: dbName,
LastUsed: time.Now(),
}
activeConnMu.Lock()
activeConns[newToken] = trackedConn
activeConnMu.Unlock()
return trackedConn, newToken, nil
}
// releaseConnection releases a connection back to the pool
func releaseConnection(token string) bool {
activeConnMu.Lock()
defer activeConnMu.Unlock()
conn, exists := activeConns[token]
if !exists {
return false
}
conn.Pool.Put(conn.Conn)
delete(activeConns, token)
return true
}
// sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int {
// Get required parameters
@ -202,20 +120,23 @@ func sqlQuery(state *luajit.State) int {
dbName := state.ToString(1)
query := state.ToString(2)
// Get connection token (optional)
var connToken string
if state.GetTop() >= 4 && state.IsString(4) {
connToken = state.ToString(4)
}
// Get connection
trackedConn, newToken, err := getConnection(connToken, dbName)
// Get pool
pool, err := getPool(dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
conn := trackedConn.Conn
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: connection timeout: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Create execution options
var execOpts sqlitex.ExecOptions
@ -223,65 +144,10 @@ func sqlQuery(state *luajit.State) int {
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) {
if state.IsTable(3) {
params, err := state.ToTable(3)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error()))
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
// Check for array-style params
if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
execOpts.Args = arrParams
} else if floatArr, ok := arr.([]float64); ok {
args := make([]any, len(floatArr))
for i, v := range floatArr {
args[i] = v
}
execOpts.Args = args
}
} else {
// Named parameters
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
}
} else {
// Positional parameters
count := state.GetTop() - 2
if state.IsString(4) {
count-- // Don't include connection token
}
args := make([]any, count)
for i := range count {
idx := i + 3
switch state.GetType(idx) {
case luajit.TypeNumber:
args[i] = state.ToNumber(idx)
case luajit.TypeString:
args[i] = state.ToString(idx)
case luajit.TypeBoolean:
args[i] = state.ToBoolean(idx)
case luajit.TypeNil:
args[i] = nil
default:
val, err := state.ToValue(idx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error()))
return -1
}
args[i] = val
}
}
execOpts.Args = args
}
}
// Set up result function
@ -300,8 +166,12 @@ func sqlQuery(state *luajit.State) int {
row[colName] = stmt.ColumnText(i)
case sqlite.TypeBlob:
blobSize := stmt.ColumnLen(i)
if blobSize > 0 {
buf := make([]byte, blobSize)
row[colName] = stmt.ColumnBytes(i, buf)
} else {
row[colName] = []byte{}
}
case sqlite.TypeNull:
row[colName] = nil
}
@ -327,10 +197,7 @@ func sqlQuery(state *luajit.State) int {
state.SetTable(-3)
}
// Return connection token
state.PushString(newToken)
return 2
return 1
}
// sqlExec executes a SQL statement without returning results
@ -344,58 +211,73 @@ func sqlExec(state *luajit.State) int {
dbName := state.ToString(1)
query := state.ToString(2)
// Get connection token (optional)
var connToken string
if state.GetTop() >= 4 && state.IsString(4) {
connToken = state.ToString(4)
}
// Get connection
trackedConn, newToken, err := getConnection(connToken, dbName)
// Get pool
pool, err := getPool(dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
conn := trackedConn.Conn
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: connection timeout: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
// Fast path for multi-statement scripts - use ExecScript
// Fast path for multi-statement scripts
if strings.Contains(query, ";") && !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
state.PushNumber(float64(conn.Changes()))
state.PushString(newToken)
return 2
return 1
}
// Fast path for simple queries with no parameters
if !hasParams {
// Use Execute for simple statements without parameters
if err := sqlitex.Execute(conn, query, nil); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
state.PushNumber(float64(conn.Changes()))
state.PushString(newToken)
return 2
return 1
}
// Create execution options for parameterized query
var execOpts sqlitex.ExecOptions
// Set up parameters
if state.IsTable(3) {
params, err := state.ToTable(3)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error()))
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// Return affected rows
state.PushNumber(float64(conn.Changes()))
return 1
}
// setupParams configures execution options with parameters from Lua
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
if state.IsTable(paramIndex) {
params, err := state.ToTable(paramIndex)
if err != nil {
return fmt.Errorf("invalid parameters: %w", err)
}
// Check for array-style params
if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
@ -420,59 +302,111 @@ func sqlExec(state *luajit.State) int {
execOpts.Named = named
}
} else {
// Positional parameters
// Positional parameters from stack
count := state.GetTop() - 2
if state.IsString(4) {
count-- // Don't include connection token
}
args := make([]any, count)
for i := range count {
idx := i + 3
switch state.GetType(idx) {
case luajit.TypeNumber:
args[i] = state.ToNumber(idx)
case luajit.TypeString:
args[i] = state.ToString(idx)
case luajit.TypeBoolean:
args[i] = state.ToBoolean(idx)
case luajit.TypeNil:
args[i] = nil
default:
val, err := state.ToValue(idx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error()))
return -1
return fmt.Errorf("invalid parameter %d: %w", i+1, err)
}
args[i] = val
}
}
execOpts.Args = args
}
// Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// Return affected rows and connection token
state.PushNumber(float64(conn.Changes()))
state.PushString(newToken)
return 2
return nil
}
// sqlClose releases a connection back to the pool
func sqlClose(state *luajit.State) int {
if state.GetTop() < 1 || !state.IsString(1) {
state.PushString("sqlite.close: requires connection token")
// sqlGetOne executes a query and returns only the first row
func sqlGetOne(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.get_one: requires database name and query")
return -1
}
token := state.ToString(1)
if releaseConnection(token) {
state.PushBoolean(true)
dbName := state.ToString(1)
query := state.ToString(2)
// Get pool
pool, err := getPool(dbName)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
conn, err := pool.Take(ctx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: connection timeout: %s", err.Error()))
return -1
}
defer pool.Put(conn)
// Create execution options
var execOpts sqlitex.ExecOptions
var result map[string]any
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
}
// Set up result function to get only first row
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
if result != nil {
return nil // Already got first row
}
result = make(map[string]any)
colCount := stmt.ColumnCount()
for i := range colCount {
colName := stmt.ColumnName(i)
switch stmt.ColumnType(i) {
case sqlite.TypeInteger:
result[colName] = stmt.ColumnInt64(i)
case sqlite.TypeFloat:
result[colName] = stmt.ColumnFloat(i)
case sqlite.TypeText:
result[colName] = stmt.ColumnText(i)
case sqlite.TypeBlob:
blobSize := stmt.ColumnLen(i)
if blobSize > 0 {
buf := make([]byte, blobSize)
result[colName] = stmt.ColumnBytes(i, buf)
} else {
state.PushBoolean(false)
result[colName] = []byte{}
}
case sqlite.TypeNull:
result[colName] = nil
}
}
return nil
}
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
// Return result or nil if no rows
if result == nil {
state.PushNil()
} else {
if err := state.PushTable(result); err != nil {
state.PushString(fmt.Sprintf("sqlite.get_one: %s", err.Error()))
return -1
}
}
return 1
@ -486,7 +420,7 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_exec", sqlExec); err != nil {
return err
}
if err := state.RegisterGoFunction("__sqlite_close", sqlClose); err != nil {
if err := state.RegisterGoFunction("__sqlite_get_one", sqlGetOne); err != nil {
return err
}
return nil