fix lots of luajit api regressions
This commit is contained in:
parent
1ad3059ff0
commit
61f66d6594
2
go.mod
2
go.mod
@ -5,7 +5,7 @@ go 1.24.1
|
|||||||
require (
|
require (
|
||||||
git.sharkk.net/Go/LRU v1.0.0
|
git.sharkk.net/Go/LRU v1.0.0
|
||||||
git.sharkk.net/Sharkk/Fin v1.2.0
|
git.sharkk.net/Sharkk/Fin v1.2.0
|
||||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0
|
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1
|
||||||
github.com/VictoriaMetrics/fastcache v1.12.4
|
github.com/VictoriaMetrics/fastcache v1.12.4
|
||||||
github.com/alexedwards/argon2id v1.0.0
|
github.com/alexedwards/argon2id v1.0.0
|
||||||
github.com/deneonet/benc v1.1.8
|
github.com/deneonet/benc v1.1.8
|
||||||
|
4
go.sum
4
go.sum
@ -2,8 +2,8 @@ git.sharkk.net/Go/LRU v1.0.0 h1:/KqdRVhHldi23aVfQZ4ss6vhCWZqA3vFiQyf1MJPpQc=
|
|||||||
git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50=
|
git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50=
|
||||||
git.sharkk.net/Sharkk/Fin v1.2.0 h1:axhme8vHRYoaB3us7PNfXzXxKOxhpS5BMuNpN8ESe6U=
|
git.sharkk.net/Sharkk/Fin v1.2.0 h1:axhme8vHRYoaB3us7PNfXzXxKOxhpS5BMuNpN8ESe6U=
|
||||||
git.sharkk.net/Sharkk/Fin v1.2.0/go.mod h1:ca0Ej9yCM/vHh1o3YMvBZspme3EtbwoEL2UXN5UPXMo=
|
git.sharkk.net/Sharkk/Fin v1.2.0/go.mod h1:ca0Ej9yCM/vHh1o3YMvBZspme3EtbwoEL2UXN5UPXMo=
|
||||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0 h1:WzIKbNIoP+P56n7EmkD9V1QZJUNMbTm3cJj2jc5qUfI=
|
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1 h1:e9rby0xJs8m2SAPv0di/LplDok88UyjcNjKu8S4d1BY=
|
||||||
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
|
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
|
||||||
github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0=
|
github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0=
|
||||||
github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI=
|
github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI=
|
||||||
github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w=
|
github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w=
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
--[[
|
--[[
|
||||||
sandbox.lua
|
sandbox.lua - Rewritten with global context storage
|
||||||
]]--
|
]]--
|
||||||
|
|
||||||
__http_response = {}
|
__http_response = {}
|
||||||
@ -8,6 +8,9 @@ __module_bytecode = {}
|
|||||||
__ready_modules = {}
|
__ready_modules = {}
|
||||||
__EXIT_SENTINEL = {} -- Unique object for exit identification
|
__EXIT_SENTINEL = {} -- Unique object for exit identification
|
||||||
|
|
||||||
|
-- Global context storage for reliable access
|
||||||
|
local _current_ctx = nil
|
||||||
|
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
-- CORE SANDBOX FUNCTIONALITY
|
-- CORE SANDBOX FUNCTIONALITY
|
||||||
-- ======================================================================
|
-- ======================================================================
|
||||||
@ -38,17 +41,20 @@ end
|
|||||||
-- Execute script with clean environment
|
-- Execute script with clean environment
|
||||||
function __execute_script(fn, ctx)
|
function __execute_script(fn, ctx)
|
||||||
__http_response = nil
|
__http_response = nil
|
||||||
|
_current_ctx = ctx -- Store globally for function access
|
||||||
|
|
||||||
local env = __create_env(ctx)
|
local env = __create_env(ctx)
|
||||||
env.exit = exit
|
env.exit = exit
|
||||||
setfenv(fn, env)
|
setfenv(fn, env)
|
||||||
|
|
||||||
local ok, result = pcall(fn)
|
local ok, result = pcall(fn)
|
||||||
|
|
||||||
|
_current_ctx = nil -- Clean up after execution
|
||||||
|
|
||||||
if not ok then
|
if not ok then
|
||||||
if result == __EXIT_SENTINEL then
|
if result == __EXIT_SENTINEL then
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
error(result, 0)
|
error(result, 0)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -258,14 +264,13 @@ function cookie_get(name)
|
|||||||
error("cookie_get: name must be a string", 2)
|
error("cookie_get: name must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local env = getfenv(2)
|
if _current_ctx then
|
||||||
|
if _current_ctx.cookies then
|
||||||
if env.ctx and env.ctx.cookies then
|
return _current_ctx.cookies[name]
|
||||||
return env.ctx.cookies[name]
|
end
|
||||||
|
if _current_ctx._request_cookies then
|
||||||
|
return _current_ctx._request_cookies[name]
|
||||||
end
|
end
|
||||||
|
|
||||||
if env.ctx and env.ctx._request_cookies then
|
|
||||||
return env.ctx._request_cookies[name]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -289,10 +294,8 @@ function session_get(key)
|
|||||||
error("session_get: key must be a string", 2)
|
error("session_get: key must be a string", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
local env = getfenv(2)
|
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
|
||||||
|
return _current_ctx.session.data[key]
|
||||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
|
||||||
return env.ctx.session.data[key]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -302,7 +305,7 @@ function session_set(key, value)
|
|||||||
if type(key) ~= "string" then
|
if type(key) ~= "string" then
|
||||||
error("session_set: key must be a string", 2)
|
error("session_set: key must be a string", 2)
|
||||||
end
|
end
|
||||||
if type(value) == nil then
|
if value == nil then
|
||||||
error("session_set: value cannot be nil", 2)
|
error("session_set: value cannot be nil", 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -310,30 +313,29 @@ function session_set(key, value)
|
|||||||
resp.session = resp.session or {}
|
resp.session = resp.session or {}
|
||||||
resp.session[key] = value
|
resp.session[key] = value
|
||||||
|
|
||||||
local env = getfenv(2)
|
-- Update current context session data
|
||||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
|
||||||
env.ctx.session.data[key] = value
|
_current_ctx.session.data[key] = value
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function session_id()
|
function session_id()
|
||||||
local env = getfenv(2)
|
if _current_ctx and _current_ctx.session then
|
||||||
|
return _current_ctx.session.id
|
||||||
if env.ctx and env.ctx.session then
|
|
||||||
return env.ctx.session.id
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
end
|
end
|
||||||
|
|
||||||
function session_get_all()
|
function session_get_all()
|
||||||
local env = getfenv(2)
|
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
|
||||||
|
-- Return a copy to prevent modification
|
||||||
if env.ctx and env.ctx.session then
|
local copy = {}
|
||||||
return env.ctx.session.data
|
for k, v in pairs(_current_ctx.session.data) do
|
||||||
|
copy[k] = v
|
||||||
end
|
end
|
||||||
|
return copy
|
||||||
return nil
|
end
|
||||||
|
return {}
|
||||||
end
|
end
|
||||||
|
|
||||||
function session_delete(key)
|
function session_delete(key)
|
||||||
@ -345,17 +347,16 @@ function session_delete(key)
|
|||||||
resp.session = resp.session or {}
|
resp.session = resp.session or {}
|
||||||
resp.session[key] = "__SESSION_DELETE_MARKER__"
|
resp.session[key] = "__SESSION_DELETE_MARKER__"
|
||||||
|
|
||||||
local env = getfenv(2)
|
-- Update current context
|
||||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
|
||||||
env.ctx.session.data[key] = nil
|
_current_ctx.session.data[key] = nil
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function session_clear()
|
function session_clear()
|
||||||
local env = getfenv(2)
|
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
|
||||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
for k, _ in pairs(_current_ctx.session.data) do
|
||||||
for k, _ in pairs(env.ctx.session.data) do
|
_current_ctx.session.data[k] = nil
|
||||||
env.ctx.session.data[k] = nil
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -384,11 +385,7 @@ function csrf_field()
|
|||||||
end
|
end
|
||||||
|
|
||||||
function csrf_validate()
|
function csrf_validate()
|
||||||
local env = getfenv(2)
|
local token = session_get("_csrf_token")
|
||||||
local token = false
|
|
||||||
if env.ctx and env.ctx.session and env.ctx.session.data then
|
|
||||||
token = env.ctx.session.data["_csrf_token"]
|
|
||||||
end
|
|
||||||
|
|
||||||
if not token then
|
if not token then
|
||||||
http_set_status(403)
|
http_set_status(403)
|
||||||
@ -397,13 +394,13 @@ function csrf_validate()
|
|||||||
end
|
end
|
||||||
|
|
||||||
local request_token = nil
|
local request_token = nil
|
||||||
if env.ctx and env.ctx.form then
|
if _current_ctx and _current_ctx.form then
|
||||||
request_token = env.ctx.form._csrf_token
|
request_token = _current_ctx.form._csrf_token
|
||||||
end
|
end
|
||||||
|
|
||||||
if not request_token and env.ctx and env.ctx._request_headers then
|
if not request_token and _current_ctx and _current_ctx._request_headers then
|
||||||
request_token = env.ctx._request_headers["x-csrf-token"] or
|
request_token = _current_ctx._request_headers["x-csrf-token"] or
|
||||||
env.ctx._request_headers["csrf-token"]
|
_current_ctx._request_headers["csrf-token"]
|
||||||
end
|
end
|
||||||
|
|
||||||
if not request_token or request_token ~= token then
|
if not request_token or request_token ~= token then
|
||||||
|
@ -177,12 +177,15 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
|
|||||||
|
|
||||||
// Extract headers using ForEachTableKV
|
// Extract headers using ForEachTableKV
|
||||||
if headerTable, ok := state.GetFieldTable(-1, "headers"); ok {
|
if headerTable, ok := state.GetFieldTable(-1, "headers"); ok {
|
||||||
if headers, ok := headerTable.(map[string]any); ok {
|
switch headers := headerTable.(type) {
|
||||||
|
case map[string]any:
|
||||||
for k, v := range headers {
|
for k, v := range headers {
|
||||||
if str, ok := v.(string); ok {
|
if str, ok := v.(string); ok {
|
||||||
response.Headers[k] = str
|
response.Headers[k] = str
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case map[string]string:
|
||||||
|
maps.Copy(response.Headers, headers)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,8 +210,19 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
|
|||||||
|
|
||||||
// Extract session data
|
// Extract session data
|
||||||
if session, ok := state.GetFieldTable(-1, "session"); ok {
|
if session, ok := state.GetFieldTable(-1, "session"); ok {
|
||||||
if sessMap, ok := session.(map[string]any); ok {
|
switch sessMap := session.(type) {
|
||||||
|
case map[string]any:
|
||||||
maps.Copy(response.SessionData, sessMap)
|
maps.Copy(response.SessionData, sessMap)
|
||||||
|
case map[string]string:
|
||||||
|
for k, v := range sessMap {
|
||||||
|
response.SessionData[k] = v
|
||||||
|
}
|
||||||
|
case map[string]int:
|
||||||
|
for k, v := range sessMap {
|
||||||
|
response.SessionData[k] = v
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
logger.Debugf("Unexpected session type: %T", session)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
111
runner/sqlite.go
111
runner/sqlite.go
@ -21,24 +21,21 @@ var (
|
|||||||
dbPools = make(map[string]*sqlitex.Pool)
|
dbPools = make(map[string]*sqlitex.Pool)
|
||||||
poolsMu sync.RWMutex
|
poolsMu sync.RWMutex
|
||||||
dataDir string
|
dataDir string
|
||||||
poolSize = 8 // Default, will be set to match runner pool size
|
poolSize = 8
|
||||||
connTimeout = 5 * time.Second
|
connTimeout = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// InitSQLite initializes the SQLite subsystem
|
|
||||||
func InitSQLite(dir string) {
|
func InitSQLite(dir string) {
|
||||||
dataDir = dir
|
dataDir = dir
|
||||||
logger.Infof("SQLite is g2g! %s", color.Yellow(dir))
|
logger.Infof("SQLite is g2g! %s", color.Yellow(dir))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSQLitePoolSize sets the pool size to match the runner pool size
|
|
||||||
func SetSQLitePoolSize(size int) {
|
func SetSQLitePoolSize(size int) {
|
||||||
if size > 0 {
|
if size > 0 {
|
||||||
poolSize = size
|
poolSize = size
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupSQLite closes all database connections
|
|
||||||
func CleanupSQLite() {
|
func CleanupSQLite() {
|
||||||
poolsMu.Lock()
|
poolsMu.Lock()
|
||||||
defer poolsMu.Unlock()
|
defer poolsMu.Unlock()
|
||||||
@ -53,15 +50,12 @@ func CleanupSQLite() {
|
|||||||
logger.Debugf("SQLite connections closed")
|
logger.Debugf("SQLite connections closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPool returns a connection pool for the database
|
|
||||||
func getPool(dbName string) (*sqlitex.Pool, error) {
|
func getPool(dbName string) (*sqlitex.Pool, error) {
|
||||||
// Validate database name
|
|
||||||
dbName = filepath.Base(dbName)
|
dbName = filepath.Base(dbName)
|
||||||
if dbName == "" || dbName[0] == '.' {
|
if dbName == "" || dbName[0] == '.' {
|
||||||
return nil, fmt.Errorf("invalid database name")
|
return nil, fmt.Errorf("invalid database name")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for existing pool
|
|
||||||
poolsMu.RLock()
|
poolsMu.RLock()
|
||||||
pool, exists := dbPools[dbName]
|
pool, exists := dbPools[dbName]
|
||||||
if exists {
|
if exists {
|
||||||
@ -70,21 +64,17 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
|||||||
}
|
}
|
||||||
poolsMu.RUnlock()
|
poolsMu.RUnlock()
|
||||||
|
|
||||||
// Create new pool under write lock
|
|
||||||
poolsMu.Lock()
|
poolsMu.Lock()
|
||||||
defer poolsMu.Unlock()
|
defer poolsMu.Unlock()
|
||||||
|
|
||||||
// Double-check if a pool was created while waiting for lock
|
|
||||||
if pool, exists = dbPools[dbName]; exists {
|
if pool, exists = dbPools[dbName]; exists {
|
||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new pool with proper size
|
|
||||||
dbPath := filepath.Join(dataDir, dbName+".db")
|
dbPath := filepath.Join(dataDir, dbName+".db")
|
||||||
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
|
||||||
PoolSize: poolSize,
|
PoolSize: poolSize,
|
||||||
PrepareConn: func(conn *sqlite.Conn) error {
|
PrepareConn: func(conn *sqlite.Conn) error {
|
||||||
// Execute PRAGMA statements individually
|
|
||||||
pragmas := []string{
|
pragmas := []string{
|
||||||
"PRAGMA journal_mode = WAL",
|
"PRAGMA journal_mode = WAL",
|
||||||
"PRAGMA synchronous = NORMAL",
|
"PRAGMA synchronous = NORMAL",
|
||||||
@ -109,7 +99,6 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
|||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlQuery executes a SQL query and returns results
|
|
||||||
func sqlQuery(state *luajit.State) int {
|
func sqlQuery(state *luajit.State) int {
|
||||||
if err := state.CheckMinArgs(2); err != nil {
|
if err := state.CheckMinArgs(2); err != nil {
|
||||||
return state.PushError("sqlite.query: %v", err)
|
return state.PushError("sqlite.query: %v", err)
|
||||||
@ -125,13 +114,11 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
return state.PushError("sqlite.query: query must be string")
|
return state.PushError("sqlite.query: query must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get pool
|
|
||||||
pool, err := getPool(dbName)
|
pool, err := getPool(dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.PushError("sqlite.query: %v", err)
|
return state.PushError("sqlite.query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection with timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@ -141,18 +128,15 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
defer pool.Put(conn)
|
defer pool.Put(conn)
|
||||||
|
|
||||||
// Create execution options
|
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
rows := make([]map[string]any, 0, 16)
|
rows := make([]any, 0, 16)
|
||||||
|
|
||||||
// Set up parameters if provided
|
|
||||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||||
return state.PushError("sqlite.query: %v", err)
|
return state.PushError("sqlite.query: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up result function
|
|
||||||
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||||
row := make(map[string]any)
|
row := make(map[string]any)
|
||||||
colCount := stmt.ColumnCount()
|
colCount := stmt.ColumnCount()
|
||||||
@ -182,12 +166,10 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute query
|
|
||||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||||
return state.PushError("sqlite.query: %v", err)
|
return state.PushError("sqlite.query: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create result using specific map type and PushValue
|
|
||||||
if err := state.PushValue(rows); err != nil {
|
if err := state.PushValue(rows); err != nil {
|
||||||
return state.PushError("sqlite.query: %v", err)
|
return state.PushError("sqlite.query: %v", err)
|
||||||
}
|
}
|
||||||
@ -195,7 +177,6 @@ func sqlQuery(state *luajit.State) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlExec executes a SQL statement without returning results
|
|
||||||
func sqlExec(state *luajit.State) int {
|
func sqlExec(state *luajit.State) int {
|
||||||
if err := state.CheckMinArgs(2); err != nil {
|
if err := state.CheckMinArgs(2); err != nil {
|
||||||
return state.PushError("sqlite.exec: %v", err)
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
@ -211,13 +192,11 @@ func sqlExec(state *luajit.State) int {
|
|||||||
return state.PushError("sqlite.exec: query must be string")
|
return state.PushError("sqlite.exec: query must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get pool
|
|
||||||
pool, err := getPool(dbName)
|
pool, err := getPool(dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.PushError("sqlite.exec: %v", err)
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection with timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@ -227,10 +206,8 @@ func sqlExec(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
defer pool.Put(conn)
|
defer pool.Put(conn)
|
||||||
|
|
||||||
// Check if parameters are provided
|
|
||||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||||
|
|
||||||
// Fast path for multi-statement scripts
|
|
||||||
if strings.Contains(query, ";") && !hasParams {
|
if strings.Contains(query, ";") && !hasParams {
|
||||||
if err := sqlitex.ExecScript(conn, query); err != nil {
|
if err := sqlitex.ExecScript(conn, query); err != nil {
|
||||||
return state.PushError("sqlite.exec: %v", err)
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
@ -239,7 +216,6 @@ func sqlExec(state *luajit.State) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fast path for simple queries with no parameters
|
|
||||||
if !hasParams {
|
if !hasParams {
|
||||||
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
if err := sqlitex.Execute(conn, query, nil); err != nil {
|
||||||
return state.PushError("sqlite.exec: %v", err)
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
@ -248,23 +224,19 @@ func sqlExec(state *luajit.State) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create execution options for parameterized query
|
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||||
return state.PushError("sqlite.exec: %v", err)
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute with parameters
|
|
||||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||||
return state.PushError("sqlite.exec: %v", err)
|
return state.PushError("sqlite.exec: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return affected rows
|
|
||||||
state.PushNumber(float64(conn.Changes()))
|
state.PushNumber(float64(conn.Changes()))
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupParams configures execution options with parameters from Lua
|
|
||||||
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
|
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
|
||||||
if state.IsTable(paramIndex) {
|
if state.IsTable(paramIndex) {
|
||||||
paramsAny, err := state.SafeToTable(paramIndex)
|
paramsAny, err := state.SafeToTable(paramIndex)
|
||||||
@ -272,13 +244,8 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
|
|||||||
return fmt.Errorf("invalid parameters: %w", err)
|
return fmt.Errorf("invalid parameters: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Type assert to map[string]any
|
switch params := paramsAny.(type) {
|
||||||
params, ok := paramsAny.(map[string]any)
|
case map[string]any:
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("parameters must be a table")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for array-style params
|
|
||||||
if arr, ok := params[""]; ok {
|
if arr, ok := params[""]; ok {
|
||||||
if arrParams, ok := arr.([]any); ok {
|
if arrParams, ok := arr.([]any); ok {
|
||||||
execOpts.Args = arrParams
|
execOpts.Args = arrParams
|
||||||
@ -290,7 +257,6 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
|
|||||||
execOpts.Args = args
|
execOpts.Args = args
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Named parameters
|
|
||||||
named := make(map[string]any, len(params))
|
named := make(map[string]any, len(params))
|
||||||
for k, v := range params {
|
for k, v := range params {
|
||||||
if len(k) > 0 && k[0] != ':' {
|
if len(k) > 0 && k[0] != ':' {
|
||||||
@ -301,8 +267,64 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
|
|||||||
}
|
}
|
||||||
execOpts.Named = named
|
execOpts.Named = named
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case map[string]string:
|
||||||
|
named := make(map[string]any, len(params))
|
||||||
|
for k, v := range params {
|
||||||
|
if len(k) > 0 && k[0] != ':' {
|
||||||
|
named[":"+k] = v
|
||||||
|
} else {
|
||||||
|
named[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
execOpts.Named = named
|
||||||
|
|
||||||
|
case map[string]int:
|
||||||
|
named := make(map[string]any, len(params))
|
||||||
|
for k, v := range params {
|
||||||
|
if len(k) > 0 && k[0] != ':' {
|
||||||
|
named[":"+k] = v
|
||||||
|
} else {
|
||||||
|
named[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
execOpts.Named = named
|
||||||
|
|
||||||
|
case map[int]any:
|
||||||
|
named := make(map[string]any, len(params))
|
||||||
|
for k, v := range params {
|
||||||
|
named[fmt.Sprintf(":%d", k)] = v
|
||||||
|
}
|
||||||
|
execOpts.Named = named
|
||||||
|
|
||||||
|
case []any:
|
||||||
|
execOpts.Args = params
|
||||||
|
|
||||||
|
case []string:
|
||||||
|
args := make([]any, len(params))
|
||||||
|
for i, v := range params {
|
||||||
|
args[i] = v
|
||||||
|
}
|
||||||
|
execOpts.Args = args
|
||||||
|
|
||||||
|
case []int:
|
||||||
|
args := make([]any, len(params))
|
||||||
|
for i, v := range params {
|
||||||
|
args[i] = v
|
||||||
|
}
|
||||||
|
execOpts.Args = args
|
||||||
|
|
||||||
|
case []float64:
|
||||||
|
args := make([]any, len(params))
|
||||||
|
for i, v := range params {
|
||||||
|
args[i] = v
|
||||||
|
}
|
||||||
|
execOpts.Args = args
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported parameter type: %T", params)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Positional parameters from stack
|
|
||||||
count := state.GetTop() - 2
|
count := state.GetTop() - 2
|
||||||
args := make([]any, count)
|
args := make([]any, count)
|
||||||
for i := range count {
|
for i := range count {
|
||||||
@ -319,7 +341,6 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sqlGetOne executes a query and returns only the first row
|
|
||||||
func sqlGetOne(state *luajit.State) int {
|
func sqlGetOne(state *luajit.State) int {
|
||||||
if err := state.CheckMinArgs(2); err != nil {
|
if err := state.CheckMinArgs(2); err != nil {
|
||||||
return state.PushError("sqlite.get_one: %v", err)
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
@ -335,13 +356,11 @@ func sqlGetOne(state *luajit.State) int {
|
|||||||
return state.PushError("sqlite.get_one: query must be string")
|
return state.PushError("sqlite.get_one: query must be string")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get pool
|
|
||||||
pool, err := getPool(dbName)
|
pool, err := getPool(dbName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.PushError("sqlite.get_one: %v", err)
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection with timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@ -351,21 +370,18 @@ func sqlGetOne(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
defer pool.Put(conn)
|
defer pool.Put(conn)
|
||||||
|
|
||||||
// Create execution options
|
|
||||||
var execOpts sqlitex.ExecOptions
|
var execOpts sqlitex.ExecOptions
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
|
|
||||||
// Set up parameters if provided
|
|
||||||
if state.GetTop() >= 3 && !state.IsNil(3) {
|
if state.GetTop() >= 3 && !state.IsNil(3) {
|
||||||
if err := setupParams(state, 3, &execOpts); err != nil {
|
if err := setupParams(state, 3, &execOpts); err != nil {
|
||||||
return state.PushError("sqlite.get_one: %v", err)
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up result function to get only first row
|
|
||||||
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||||
if result != nil {
|
if result != nil {
|
||||||
return nil // Already got first row
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result = make(map[string]any)
|
result = make(map[string]any)
|
||||||
@ -395,12 +411,10 @@ func sqlGetOne(state *luajit.State) int {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute query
|
|
||||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||||
return state.PushError("sqlite.get_one: %v", err)
|
return state.PushError("sqlite.get_one: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return result or nil if no rows
|
|
||||||
if result == nil {
|
if result == nil {
|
||||||
state.PushNil()
|
state.PushNil()
|
||||||
} else {
|
} else {
|
||||||
@ -412,7 +426,6 @@ func sqlGetOne(state *luajit.State) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
|
|
||||||
func RegisterSQLiteFunctions(state *luajit.State) error {
|
func RegisterSQLiteFunctions(state *luajit.State) error {
|
||||||
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
|
Loading…
x
Reference in New Issue
Block a user