fix lots of luajit api regressions

This commit is contained in:
Sky Johnson 2025-06-02 22:18:54 -05:00
parent 1ad3059ff0
commit 61f66d6594
5 changed files with 134 additions and 110 deletions

2
go.mod
View File

@ -5,7 +5,7 @@ go 1.24.1
require (
git.sharkk.net/Go/LRU v1.0.0
git.sharkk.net/Sharkk/Fin v1.2.0
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1
github.com/VictoriaMetrics/fastcache v1.12.4
github.com/alexedwards/argon2id v1.0.0
github.com/deneonet/benc v1.1.8

4
go.sum
View File

@ -2,8 +2,8 @@ git.sharkk.net/Go/LRU v1.0.0 h1:/KqdRVhHldi23aVfQZ4ss6vhCWZqA3vFiQyf1MJPpQc=
git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50=
git.sharkk.net/Sharkk/Fin v1.2.0 h1:axhme8vHRYoaB3us7PNfXzXxKOxhpS5BMuNpN8ESe6U=
git.sharkk.net/Sharkk/Fin v1.2.0/go.mod h1:ca0Ej9yCM/vHh1o3YMvBZspme3EtbwoEL2UXN5UPXMo=
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0 h1:WzIKbNIoP+P56n7EmkD9V1QZJUNMbTm3cJj2jc5qUfI=
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1 h1:e9rby0xJs8m2SAPv0di/LplDok88UyjcNjKu8S4d1BY=
git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8=
github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0=
github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI=
github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w=

View File

@ -1,5 +1,5 @@
--[[
sandbox.lua
sandbox.lua - Rewritten with global context storage
]]--
__http_response = {}
@ -8,6 +8,9 @@ __module_bytecode = {}
__ready_modules = {}
__EXIT_SENTINEL = {} -- Unique object for exit identification
-- Global context storage for reliable access
local _current_ctx = nil
-- ======================================================================
-- CORE SANDBOX FUNCTIONALITY
-- ======================================================================
@ -38,17 +41,20 @@ end
-- Execute script with clean environment
function __execute_script(fn, ctx)
__http_response = nil
_current_ctx = ctx -- Store globally for function access
local env = __create_env(ctx)
env.exit = exit
setfenv(fn, env)
local ok, result = pcall(fn)
_current_ctx = nil -- Clean up after execution
if not ok then
if result == __EXIT_SENTINEL then
return
end
error(result, 0)
end
@ -258,14 +264,13 @@ function cookie_get(name)
error("cookie_get: name must be a string", 2)
end
local env = getfenv(2)
if env.ctx and env.ctx.cookies then
return env.ctx.cookies[name]
if _current_ctx then
if _current_ctx.cookies then
return _current_ctx.cookies[name]
end
if _current_ctx._request_cookies then
return _current_ctx._request_cookies[name]
end
if env.ctx and env.ctx._request_cookies then
return env.ctx._request_cookies[name]
end
return nil
@ -289,10 +294,8 @@ function session_get(key)
error("session_get: key must be a string", 2)
end
local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then
return env.ctx.session.data[key]
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
return _current_ctx.session.data[key]
end
return nil
@ -302,7 +305,7 @@ function session_set(key, value)
if type(key) ~= "string" then
error("session_set: key must be a string", 2)
end
if type(value) == nil then
if value == nil then
error("session_set: value cannot be nil", 2)
end
@ -310,30 +313,29 @@ function session_set(key, value)
resp.session = resp.session or {}
resp.session[key] = value
local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then
env.ctx.session.data[key] = value
-- Update current context session data
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
_current_ctx.session.data[key] = value
end
end
function session_id()
local env = getfenv(2)
if env.ctx and env.ctx.session then
return env.ctx.session.id
if _current_ctx and _current_ctx.session then
return _current_ctx.session.id
end
return nil
end
function session_get_all()
local env = getfenv(2)
if env.ctx and env.ctx.session then
return env.ctx.session.data
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
-- Return a copy to prevent modification
local copy = {}
for k, v in pairs(_current_ctx.session.data) do
copy[k] = v
end
return nil
return copy
end
return {}
end
function session_delete(key)
@ -345,17 +347,16 @@ function session_delete(key)
resp.session = resp.session or {}
resp.session[key] = "__SESSION_DELETE_MARKER__"
local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then
env.ctx.session.data[key] = nil
-- Update current context
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
_current_ctx.session.data[key] = nil
end
end
function session_clear()
local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then
for k, _ in pairs(env.ctx.session.data) do
env.ctx.session.data[k] = nil
if _current_ctx and _current_ctx.session and _current_ctx.session.data then
for k, _ in pairs(_current_ctx.session.data) do
_current_ctx.session.data[k] = nil
end
end
@ -384,11 +385,7 @@ function csrf_field()
end
function csrf_validate()
local env = getfenv(2)
local token = false
if env.ctx and env.ctx.session and env.ctx.session.data then
token = env.ctx.session.data["_csrf_token"]
end
local token = session_get("_csrf_token")
if not token then
http_set_status(403)
@ -397,13 +394,13 @@ function csrf_validate()
end
local request_token = nil
if env.ctx and env.ctx.form then
request_token = env.ctx.form._csrf_token
if _current_ctx and _current_ctx.form then
request_token = _current_ctx.form._csrf_token
end
if not request_token and env.ctx and env.ctx._request_headers then
request_token = env.ctx._request_headers["x-csrf-token"] or
env.ctx._request_headers["csrf-token"]
if not request_token and _current_ctx and _current_ctx._request_headers then
request_token = _current_ctx._request_headers["x-csrf-token"] or
_current_ctx._request_headers["csrf-token"]
end
if not request_token or request_token ~= token then

View File

@ -177,12 +177,15 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
// Extract headers using ForEachTableKV
if headerTable, ok := state.GetFieldTable(-1, "headers"); ok {
if headers, ok := headerTable.(map[string]any); ok {
switch headers := headerTable.(type) {
case map[string]any:
for k, v := range headers {
if str, ok := v.(string); ok {
response.Headers[k] = str
}
}
case map[string]string:
maps.Copy(response.Headers, headers)
}
}
@ -207,8 +210,19 @@ func extractHTTPResponseData(state *luajit.State, response *Response) {
// Extract session data
if session, ok := state.GetFieldTable(-1, "session"); ok {
if sessMap, ok := session.(map[string]any); ok {
switch sessMap := session.(type) {
case map[string]any:
maps.Copy(response.SessionData, sessMap)
case map[string]string:
for k, v := range sessMap {
response.SessionData[k] = v
}
case map[string]int:
for k, v := range sessMap {
response.SessionData[k] = v
}
default:
logger.Debugf("Unexpected session type: %T", session)
}
}

View File

@ -21,24 +21,21 @@ var (
dbPools = make(map[string]*sqlitex.Pool)
poolsMu sync.RWMutex
dataDir string
poolSize = 8 // Default, will be set to match runner pool size
poolSize = 8
connTimeout = 5 * time.Second
)
// InitSQLite initializes the SQLite subsystem
func InitSQLite(dir string) {
dataDir = dir
logger.Infof("SQLite is g2g! %s", color.Yellow(dir))
}
// SetSQLitePoolSize sets the pool size to match the runner pool size
func SetSQLitePoolSize(size int) {
if size > 0 {
poolSize = size
}
}
// CleanupSQLite closes all database connections
func CleanupSQLite() {
poolsMu.Lock()
defer poolsMu.Unlock()
@ -53,15 +50,12 @@ func CleanupSQLite() {
logger.Debugf("SQLite connections closed")
}
// getPool returns a connection pool for the database
func getPool(dbName string) (*sqlitex.Pool, error) {
// Validate database name
dbName = filepath.Base(dbName)
if dbName == "" || dbName[0] == '.' {
return nil, fmt.Errorf("invalid database name")
}
// Check for existing pool
poolsMu.RLock()
pool, exists := dbPools[dbName]
if exists {
@ -70,21 +64,17 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
}
poolsMu.RUnlock()
// Create new pool under write lock
poolsMu.Lock()
defer poolsMu.Unlock()
// Double-check if a pool was created while waiting for lock
if pool, exists = dbPools[dbName]; exists {
return pool, nil
}
// Create new pool with proper size
dbPath := filepath.Join(dataDir, dbName+".db")
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{
PoolSize: poolSize,
PrepareConn: func(conn *sqlite.Conn) error {
// Execute PRAGMA statements individually
pragmas := []string{
"PRAGMA journal_mode = WAL",
"PRAGMA synchronous = NORMAL",
@ -109,7 +99,6 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return pool, nil
}
// sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int {
if err := state.CheckMinArgs(2); err != nil {
return state.PushError("sqlite.query: %v", err)
@ -125,13 +114,11 @@ func sqlQuery(state *luajit.State) int {
return state.PushError("sqlite.query: query must be string")
}
// Get pool
pool, err := getPool(dbName)
if err != nil {
return state.PushError("sqlite.query: %v", err)
}
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
@ -141,18 +128,15 @@ func sqlQuery(state *luajit.State) int {
}
defer pool.Put(conn)
// Create execution options
var execOpts sqlitex.ExecOptions
rows := make([]map[string]any, 0, 16)
rows := make([]any, 0, 16)
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil {
return state.PushError("sqlite.query: %v", err)
}
}
// Set up result function
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
row := make(map[string]any)
colCount := stmt.ColumnCount()
@ -182,12 +166,10 @@ func sqlQuery(state *luajit.State) int {
return nil
}
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
return state.PushError("sqlite.query: %v", err)
}
// Create result using specific map type and PushValue
if err := state.PushValue(rows); err != nil {
return state.PushError("sqlite.query: %v", err)
}
@ -195,7 +177,6 @@ func sqlQuery(state *luajit.State) int {
return 1
}
// sqlExec executes a SQL statement without returning results
func sqlExec(state *luajit.State) int {
if err := state.CheckMinArgs(2); err != nil {
return state.PushError("sqlite.exec: %v", err)
@ -211,13 +192,11 @@ func sqlExec(state *luajit.State) int {
return state.PushError("sqlite.exec: query must be string")
}
// Get pool
pool, err := getPool(dbName)
if err != nil {
return state.PushError("sqlite.exec: %v", err)
}
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
@ -227,10 +206,8 @@ func sqlExec(state *luajit.State) int {
}
defer pool.Put(conn)
// Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
// Fast path for multi-statement scripts
if strings.Contains(query, ";") && !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil {
return state.PushError("sqlite.exec: %v", err)
@ -239,7 +216,6 @@ func sqlExec(state *luajit.State) int {
return 1
}
// Fast path for simple queries with no parameters
if !hasParams {
if err := sqlitex.Execute(conn, query, nil); err != nil {
return state.PushError("sqlite.exec: %v", err)
@ -248,23 +224,19 @@ func sqlExec(state *luajit.State) int {
return 1
}
// Create execution options for parameterized query
var execOpts sqlitex.ExecOptions
if err := setupParams(state, 3, &execOpts); err != nil {
return state.PushError("sqlite.exec: %v", err)
}
// Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
return state.PushError("sqlite.exec: %v", err)
}
// Return affected rows
state.PushNumber(float64(conn.Changes()))
return 1
}
// setupParams configures execution options with parameters from Lua
func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error {
if state.IsTable(paramIndex) {
paramsAny, err := state.SafeToTable(paramIndex)
@ -272,13 +244,8 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
return fmt.Errorf("invalid parameters: %w", err)
}
// Type assert to map[string]any
params, ok := paramsAny.(map[string]any)
if !ok {
return fmt.Errorf("parameters must be a table")
}
// Check for array-style params
switch params := paramsAny.(type) {
case map[string]any:
if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
execOpts.Args = arrParams
@ -290,7 +257,6 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
execOpts.Args = args
}
} else {
// Named parameters
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
@ -301,8 +267,64 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
}
execOpts.Named = named
}
case map[string]string:
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
case map[string]int:
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
case map[int]any:
named := make(map[string]any, len(params))
for k, v := range params {
named[fmt.Sprintf(":%d", k)] = v
}
execOpts.Named = named
case []any:
execOpts.Args = params
case []string:
args := make([]any, len(params))
for i, v := range params {
args[i] = v
}
execOpts.Args = args
case []int:
args := make([]any, len(params))
for i, v := range params {
args[i] = v
}
execOpts.Args = args
case []float64:
args := make([]any, len(params))
for i, v := range params {
args[i] = v
}
execOpts.Args = args
default:
return fmt.Errorf("unsupported parameter type: %T", params)
}
} else {
// Positional parameters from stack
count := state.GetTop() - 2
args := make([]any, count)
for i := range count {
@ -319,7 +341,6 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti
return nil
}
// sqlGetOne executes a query and returns only the first row
func sqlGetOne(state *luajit.State) int {
if err := state.CheckMinArgs(2); err != nil {
return state.PushError("sqlite.get_one: %v", err)
@ -335,13 +356,11 @@ func sqlGetOne(state *luajit.State) int {
return state.PushError("sqlite.get_one: query must be string")
}
// Get pool
pool, err := getPool(dbName)
if err != nil {
return state.PushError("sqlite.get_one: %v", err)
}
// Get connection with timeout
ctx, cancel := context.WithTimeout(context.Background(), connTimeout)
defer cancel()
@ -351,21 +370,18 @@ func sqlGetOne(state *luajit.State) int {
}
defer pool.Put(conn)
// Create execution options
var execOpts sqlitex.ExecOptions
var result map[string]any
// Set up parameters if provided
if state.GetTop() >= 3 && !state.IsNil(3) {
if err := setupParams(state, 3, &execOpts); err != nil {
return state.PushError("sqlite.get_one: %v", err)
}
}
// Set up result function to get only first row
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
if result != nil {
return nil // Already got first row
return nil
}
result = make(map[string]any)
@ -395,12 +411,10 @@ func sqlGetOne(state *luajit.State) int {
return nil
}
// Execute query
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
return state.PushError("sqlite.get_one: %v", err)
}
// Return result or nil if no rows
if result == nil {
state.PushNil()
} else {
@ -412,7 +426,6 @@ func sqlGetOne(state *luajit.State) int {
return 1
}
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
return err