diff --git a/go.mod b/go.mod index eca13be..f4165c4 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.24.1 require ( git.sharkk.net/Go/LRU v1.0.0 git.sharkk.net/Sharkk/Fin v1.2.0 - git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0 + git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1 github.com/VictoriaMetrics/fastcache v1.12.4 github.com/alexedwards/argon2id v1.0.0 github.com/deneonet/benc v1.1.8 diff --git a/go.sum b/go.sum index ca197a2..af87860 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ git.sharkk.net/Go/LRU v1.0.0 h1:/KqdRVhHldi23aVfQZ4ss6vhCWZqA3vFiQyf1MJPpQc= git.sharkk.net/Go/LRU v1.0.0/go.mod h1:8tdTyl85mss9a+KKwo+Wj9gKHOizhfLfpJhz1ltYz50= git.sharkk.net/Sharkk/Fin v1.2.0 h1:axhme8vHRYoaB3us7PNfXzXxKOxhpS5BMuNpN8ESe6U= git.sharkk.net/Sharkk/Fin v1.2.0/go.mod h1:ca0Ej9yCM/vHh1o3YMvBZspme3EtbwoEL2UXN5UPXMo= -git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0 h1:WzIKbNIoP+P56n7EmkD9V1QZJUNMbTm3cJj2jc5qUfI= -git.sharkk.net/Sky/LuaJIT-to-Go v0.5.0/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8= +git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1 h1:e9rby0xJs8m2SAPv0di/LplDok88UyjcNjKu8S4d1BY= +git.sharkk.net/Sky/LuaJIT-to-Go v0.5.1/go.mod h1:HQz+D7AFxOfNbTIogjxP+shEBtz1KKrLlLucU+w07c8= github.com/VictoriaMetrics/fastcache v1.12.4 h1:2xvmwZBW+9QtHsXggfzAZRs1FZWCsBs8QDg22bMidf0= github.com/VictoriaMetrics/fastcache v1.12.4/go.mod h1:K+JGPBn0sueFlLjZ8rcVM0cKkWKNElKyQXmw57QOoYI= github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w= diff --git a/runner/lua/sandbox.lua b/runner/lua/sandbox.lua index d0783d3..e5b9357 100644 --- a/runner/lua/sandbox.lua +++ b/runner/lua/sandbox.lua @@ -1,5 +1,5 @@ --[[ -sandbox.lua +sandbox.lua - Rewritten with global context storage ]]-- __http_response = {} @@ -8,6 +8,9 @@ __module_bytecode = {} __ready_modules = {} __EXIT_SENTINEL = {} -- Unique object for exit identification +-- Global context storage for reliable access +local _current_ctx = nil + -- ====================================================================== -- CORE SANDBOX FUNCTIONALITY -- ====================================================================== @@ -38,17 +41,20 @@ end -- Execute script with clean environment function __execute_script(fn, ctx) __http_response = nil + _current_ctx = ctx -- Store globally for function access local env = __create_env(ctx) env.exit = exit setfenv(fn, env) local ok, result = pcall(fn) + + _current_ctx = nil -- Clean up after execution + if not ok then if result == __EXIT_SENTINEL then return end - error(result, 0) end @@ -258,14 +264,13 @@ function cookie_get(name) error("cookie_get: name must be a string", 2) end - local env = getfenv(2) - - if env.ctx and env.ctx.cookies then - return env.ctx.cookies[name] - end - - if env.ctx and env.ctx._request_cookies then - return env.ctx._request_cookies[name] + if _current_ctx then + if _current_ctx.cookies then + return _current_ctx.cookies[name] + end + if _current_ctx._request_cookies then + return _current_ctx._request_cookies[name] + end end return nil @@ -289,10 +294,8 @@ function session_get(key) error("session_get: key must be a string", 2) end - local env = getfenv(2) - - if env.ctx and env.ctx.session and env.ctx.session.data then - return env.ctx.session.data[key] + if _current_ctx and _current_ctx.session and _current_ctx.session.data then + return _current_ctx.session.data[key] end return nil @@ -302,7 +305,7 @@ function session_set(key, value) if type(key) ~= "string" then error("session_set: key must be a string", 2) end - if type(value) == nil then + if value == nil then error("session_set: value cannot be nil", 2) end @@ -310,30 +313,29 @@ function session_set(key, value) resp.session = resp.session or {} resp.session[key] = value - local env = getfenv(2) - if env.ctx and env.ctx.session and env.ctx.session.data then - env.ctx.session.data[key] = value + -- Update current context session data + if _current_ctx and _current_ctx.session and _current_ctx.session.data then + _current_ctx.session.data[key] = value end end function session_id() - local env = getfenv(2) - - if env.ctx and env.ctx.session then - return env.ctx.session.id + if _current_ctx and _current_ctx.session then + return _current_ctx.session.id end - return nil end function session_get_all() - local env = getfenv(2) - - if env.ctx and env.ctx.session then - return env.ctx.session.data + if _current_ctx and _current_ctx.session and _current_ctx.session.data then + -- Return a copy to prevent modification + local copy = {} + for k, v in pairs(_current_ctx.session.data) do + copy[k] = v + end + return copy end - - return nil + return {} end function session_delete(key) @@ -345,17 +347,16 @@ function session_delete(key) resp.session = resp.session or {} resp.session[key] = "__SESSION_DELETE_MARKER__" - local env = getfenv(2) - if env.ctx and env.ctx.session and env.ctx.session.data then - env.ctx.session.data[key] = nil + -- Update current context + if _current_ctx and _current_ctx.session and _current_ctx.session.data then + _current_ctx.session.data[key] = nil end end function session_clear() - local env = getfenv(2) - if env.ctx and env.ctx.session and env.ctx.session.data then - for k, _ in pairs(env.ctx.session.data) do - env.ctx.session.data[k] = nil + if _current_ctx and _current_ctx.session and _current_ctx.session.data then + for k, _ in pairs(_current_ctx.session.data) do + _current_ctx.session.data[k] = nil end end @@ -384,11 +385,7 @@ function csrf_field() end function csrf_validate() - local env = getfenv(2) - local token = false - if env.ctx and env.ctx.session and env.ctx.session.data then - token = env.ctx.session.data["_csrf_token"] - end + local token = session_get("_csrf_token") if not token then http_set_status(403) @@ -397,13 +394,13 @@ function csrf_validate() end local request_token = nil - if env.ctx and env.ctx.form then - request_token = env.ctx.form._csrf_token + if _current_ctx and _current_ctx.form then + request_token = _current_ctx.form._csrf_token end - if not request_token and env.ctx and env.ctx._request_headers then - request_token = env.ctx._request_headers["x-csrf-token"] or - env.ctx._request_headers["csrf-token"] + if not request_token and _current_ctx and _current_ctx._request_headers then + request_token = _current_ctx._request_headers["x-csrf-token"] or + _current_ctx._request_headers["csrf-token"] end if not request_token or request_token ~= token then @@ -664,4 +661,4 @@ end function send_binary(content, mime_type) http_set_content_type(mime_type or "application/octet-stream") return content -end +end \ No newline at end of file diff --git a/runner/sandbox.go b/runner/sandbox.go index d9e236e..5d0b925 100644 --- a/runner/sandbox.go +++ b/runner/sandbox.go @@ -177,12 +177,15 @@ func extractHTTPResponseData(state *luajit.State, response *Response) { // Extract headers using ForEachTableKV if headerTable, ok := state.GetFieldTable(-1, "headers"); ok { - if headers, ok := headerTable.(map[string]any); ok { + switch headers := headerTable.(type) { + case map[string]any: for k, v := range headers { if str, ok := v.(string); ok { response.Headers[k] = str } } + case map[string]string: + maps.Copy(response.Headers, headers) } } @@ -207,8 +210,19 @@ func extractHTTPResponseData(state *luajit.State, response *Response) { // Extract session data if session, ok := state.GetFieldTable(-1, "session"); ok { - if sessMap, ok := session.(map[string]any); ok { + switch sessMap := session.(type) { + case map[string]any: maps.Copy(response.SessionData, sessMap) + case map[string]string: + for k, v := range sessMap { + response.SessionData[k] = v + } + case map[string]int: + for k, v := range sessMap { + response.SessionData[k] = v + } + default: + logger.Debugf("Unexpected session type: %T", session) } } diff --git a/runner/sqlite.go b/runner/sqlite.go index 80ca6a5..cb08864 100644 --- a/runner/sqlite.go +++ b/runner/sqlite.go @@ -21,24 +21,21 @@ var ( dbPools = make(map[string]*sqlitex.Pool) poolsMu sync.RWMutex dataDir string - poolSize = 8 // Default, will be set to match runner pool size + poolSize = 8 connTimeout = 5 * time.Second ) -// InitSQLite initializes the SQLite subsystem func InitSQLite(dir string) { dataDir = dir logger.Infof("SQLite is g2g! %s", color.Yellow(dir)) } -// SetSQLitePoolSize sets the pool size to match the runner pool size func SetSQLitePoolSize(size int) { if size > 0 { poolSize = size } } -// CleanupSQLite closes all database connections func CleanupSQLite() { poolsMu.Lock() defer poolsMu.Unlock() @@ -53,15 +50,12 @@ func CleanupSQLite() { logger.Debugf("SQLite connections closed") } -// getPool returns a connection pool for the database func getPool(dbName string) (*sqlitex.Pool, error) { - // Validate database name dbName = filepath.Base(dbName) if dbName == "" || dbName[0] == '.' { return nil, fmt.Errorf("invalid database name") } - // Check for existing pool poolsMu.RLock() pool, exists := dbPools[dbName] if exists { @@ -70,21 +64,17 @@ func getPool(dbName string) (*sqlitex.Pool, error) { } poolsMu.RUnlock() - // Create new pool under write lock poolsMu.Lock() defer poolsMu.Unlock() - // Double-check if a pool was created while waiting for lock if pool, exists = dbPools[dbName]; exists { return pool, nil } - // Create new pool with proper size dbPath := filepath.Join(dataDir, dbName+".db") pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{ PoolSize: poolSize, PrepareConn: func(conn *sqlite.Conn) error { - // Execute PRAGMA statements individually pragmas := []string{ "PRAGMA journal_mode = WAL", "PRAGMA synchronous = NORMAL", @@ -109,7 +99,6 @@ func getPool(dbName string) (*sqlitex.Pool, error) { return pool, nil } -// sqlQuery executes a SQL query and returns results func sqlQuery(state *luajit.State) int { if err := state.CheckMinArgs(2); err != nil { return state.PushError("sqlite.query: %v", err) @@ -125,13 +114,11 @@ func sqlQuery(state *luajit.State) int { return state.PushError("sqlite.query: query must be string") } - // Get pool pool, err := getPool(dbName) if err != nil { return state.PushError("sqlite.query: %v", err) } - // Get connection with timeout ctx, cancel := context.WithTimeout(context.Background(), connTimeout) defer cancel() @@ -141,18 +128,15 @@ func sqlQuery(state *luajit.State) int { } defer pool.Put(conn) - // Create execution options var execOpts sqlitex.ExecOptions - rows := make([]map[string]any, 0, 16) + rows := make([]any, 0, 16) - // Set up parameters if provided if state.GetTop() >= 3 && !state.IsNil(3) { if err := setupParams(state, 3, &execOpts); err != nil { return state.PushError("sqlite.query: %v", err) } } - // Set up result function execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { row := make(map[string]any) colCount := stmt.ColumnCount() @@ -182,12 +166,10 @@ func sqlQuery(state *luajit.State) int { return nil } - // Execute query if err := sqlitex.Execute(conn, query, &execOpts); err != nil { return state.PushError("sqlite.query: %v", err) } - // Create result using specific map type and PushValue if err := state.PushValue(rows); err != nil { return state.PushError("sqlite.query: %v", err) } @@ -195,7 +177,6 @@ func sqlQuery(state *luajit.State) int { return 1 } -// sqlExec executes a SQL statement without returning results func sqlExec(state *luajit.State) int { if err := state.CheckMinArgs(2); err != nil { return state.PushError("sqlite.exec: %v", err) @@ -211,13 +192,11 @@ func sqlExec(state *luajit.State) int { return state.PushError("sqlite.exec: query must be string") } - // Get pool pool, err := getPool(dbName) if err != nil { return state.PushError("sqlite.exec: %v", err) } - // Get connection with timeout ctx, cancel := context.WithTimeout(context.Background(), connTimeout) defer cancel() @@ -227,10 +206,8 @@ func sqlExec(state *luajit.State) int { } defer pool.Put(conn) - // Check if parameters are provided hasParams := state.GetTop() >= 3 && !state.IsNil(3) - // Fast path for multi-statement scripts if strings.Contains(query, ";") && !hasParams { if err := sqlitex.ExecScript(conn, query); err != nil { return state.PushError("sqlite.exec: %v", err) @@ -239,7 +216,6 @@ func sqlExec(state *luajit.State) int { return 1 } - // Fast path for simple queries with no parameters if !hasParams { if err := sqlitex.Execute(conn, query, nil); err != nil { return state.PushError("sqlite.exec: %v", err) @@ -248,23 +224,19 @@ func sqlExec(state *luajit.State) int { return 1 } - // Create execution options for parameterized query var execOpts sqlitex.ExecOptions if err := setupParams(state, 3, &execOpts); err != nil { return state.PushError("sqlite.exec: %v", err) } - // Execute with parameters if err := sqlitex.Execute(conn, query, &execOpts); err != nil { return state.PushError("sqlite.exec: %v", err) } - // Return affected rows state.PushNumber(float64(conn.Changes())) return 1 } -// setupParams configures execution options with parameters from Lua func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOptions) error { if state.IsTable(paramIndex) { paramsAny, err := state.SafeToTable(paramIndex) @@ -272,25 +244,31 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti return fmt.Errorf("invalid parameters: %w", err) } - // Type assert to map[string]any - params, ok := paramsAny.(map[string]any) - if !ok { - return fmt.Errorf("parameters must be a table") - } - - // Check for array-style params - if arr, ok := params[""]; ok { - if arrParams, ok := arr.([]any); ok { - execOpts.Args = arrParams - } else if floatArr, ok := arr.([]float64); ok { - args := make([]any, len(floatArr)) - for i, v := range floatArr { - args[i] = v + switch params := paramsAny.(type) { + case map[string]any: + if arr, ok := params[""]; ok { + if arrParams, ok := arr.([]any); ok { + execOpts.Args = arrParams + } else if floatArr, ok := arr.([]float64); ok { + args := make([]any, len(floatArr)) + for i, v := range floatArr { + args[i] = v + } + execOpts.Args = args } - execOpts.Args = args + } else { + named := make(map[string]any, len(params)) + for k, v := range params { + if len(k) > 0 && k[0] != ':' { + named[":"+k] = v + } else { + named[k] = v + } + } + execOpts.Named = named } - } else { - // Named parameters + + case map[string]string: named := make(map[string]any, len(params)) for k, v := range params { if len(k) > 0 && k[0] != ':' { @@ -300,9 +278,53 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti } } execOpts.Named = named + + case map[string]int: + named := make(map[string]any, len(params)) + for k, v := range params { + if len(k) > 0 && k[0] != ':' { + named[":"+k] = v + } else { + named[k] = v + } + } + execOpts.Named = named + + case map[int]any: + named := make(map[string]any, len(params)) + for k, v := range params { + named[fmt.Sprintf(":%d", k)] = v + } + execOpts.Named = named + + case []any: + execOpts.Args = params + + case []string: + args := make([]any, len(params)) + for i, v := range params { + args[i] = v + } + execOpts.Args = args + + case []int: + args := make([]any, len(params)) + for i, v := range params { + args[i] = v + } + execOpts.Args = args + + case []float64: + args := make([]any, len(params)) + for i, v := range params { + args[i] = v + } + execOpts.Args = args + + default: + return fmt.Errorf("unsupported parameter type: %T", params) } } else { - // Positional parameters from stack count := state.GetTop() - 2 args := make([]any, count) for i := range count { @@ -319,7 +341,6 @@ func setupParams(state *luajit.State, paramIndex int, execOpts *sqlitex.ExecOpti return nil } -// sqlGetOne executes a query and returns only the first row func sqlGetOne(state *luajit.State) int { if err := state.CheckMinArgs(2); err != nil { return state.PushError("sqlite.get_one: %v", err) @@ -335,13 +356,11 @@ func sqlGetOne(state *luajit.State) int { return state.PushError("sqlite.get_one: query must be string") } - // Get pool pool, err := getPool(dbName) if err != nil { return state.PushError("sqlite.get_one: %v", err) } - // Get connection with timeout ctx, cancel := context.WithTimeout(context.Background(), connTimeout) defer cancel() @@ -351,21 +370,18 @@ func sqlGetOne(state *luajit.State) int { } defer pool.Put(conn) - // Create execution options var execOpts sqlitex.ExecOptions var result map[string]any - // Set up parameters if provided if state.GetTop() >= 3 && !state.IsNil(3) { if err := setupParams(state, 3, &execOpts); err != nil { return state.PushError("sqlite.get_one: %v", err) } } - // Set up result function to get only first row execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { if result != nil { - return nil // Already got first row + return nil } result = make(map[string]any) @@ -395,12 +411,10 @@ func sqlGetOne(state *luajit.State) int { return nil } - // Execute query if err := sqlitex.Execute(conn, query, &execOpts); err != nil { return state.PushError("sqlite.get_one: %v", err) } - // Return result or nil if no rows if result == nil { state.PushNil() } else { @@ -412,7 +426,6 @@ func sqlGetOne(state *luajit.State) int { return 1 } -// RegisterSQLiteFunctions registers SQLite functions with the Lua state func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil { return err