diff --git a/core/runner/fs.lua b/core/runner/fs.lua index 21cb875..6f38bc7 100644 --- a/core/runner/fs.lua +++ b/core/runner/fs.lua @@ -1,29 +1,28 @@ local fs = {} --- File Operations -fs.read_file = function(path) +fs.read = function(path) if type(path) ~= "string" then - error("fs.read_file: path must be a string", 2) + error("fs.read: path must be a string", 2) end return __fs_read_file(path) end -fs.write_file = function(path, content) +fs.write = function(path, content) if type(path) ~= "string" then - error("fs.write_file: path must be a string", 2) + error("fs.write: path must be a string", 2) end if type(content) ~= "string" then - error("fs.write_file: content must be a string", 2) + error("fs.write: content must be a string", 2) end return __fs_write_file(path, content) end -fs.append_file = function(path, content) +fs.append = function(path, content) if type(path) ~= "string" then - error("fs.append_file: path must be a string", 2) + error("fs.append: path must be a string", 2) end if type(content) ~= "string" then - error("fs.append_file: content must be a string", 2) + error("fs.append: content must be a string", 2) end return __fs_append_file(path, content) end @@ -35,16 +34,16 @@ fs.exists = function(path) return __fs_exists(path) end -fs.remove_file = function(path) +fs.remove = function(path) if type(path) ~= "string" then - error("fs.remove_file: path must be a string", 2) + error("fs.remove: path must be a string", 2) end return __fs_remove_file(path) end -fs.get_info = function(path) +fs.info = function(path) if type(path) ~= "string" then - error("fs.get_info: path must be a string", 2) + error("fs.info: path must be a string", 2) end local info = __fs_get_info(path) @@ -57,24 +56,24 @@ fs.get_info = function(path) end -- Directory Operations -fs.make_dir = function(path, mode) +fs.mkdir = function(path, mode) if type(path) ~= "string" then - error("fs.make_dir: path must be a string", 2) + error("fs.mkdir: path must be a string", 2) end mode = mode or 0755 return __fs_make_dir(path, mode) end -fs.list_dir = function(path) +fs.ls = function(path) if type(path) ~= "string" then - error("fs.list_dir: path must be a string", 2) + error("fs.ls: path must be a string", 2) end return __fs_list_dir(path) end -fs.remove_dir = function(path, recursive) +fs.rmdir = function(path, recursive) if type(path) ~= "string" then - error("fs.remove_dir: path must be a string", 2) + error("fs.rmdir: path must be a string", 2) end recursive = recursive or false return __fs_remove_dir(path, recursive) diff --git a/core/runner/password.go b/core/runner/password.go index 30c391a..d96b348 100644 --- a/core/runner/password.go +++ b/core/runner/password.go @@ -28,7 +28,7 @@ func passwordHash(state *luajit.State) int { password := state.ToString(1) params := &argon2id.Params{ - Memory: 64 * 1024, + Memory: 128 * 1024, Iterations: 4, Parallelism: 4, SaltLength: 16, diff --git a/core/runner/sandbox.lua b/core/runner/sandbox.lua index c0cdb75..5eede95 100644 --- a/core/runner/sandbox.lua +++ b/core/runner/sandbox.lua @@ -585,8 +585,8 @@ local password = {} -- Hash a password using Argon2id -- Options: --- memory: Amount of memory to use in KB (default: 64MB) --- iterations: Number of iterations (default: 3) +-- memory: Amount of memory to use in KB (default: 128MB) +-- iterations: Number of iterations (default: 4) -- parallelism: Number of threads (default: 4) -- salt_length: Length of salt in bytes (default: 16) -- key_length: Length of the derived key in bytes (default: 32) diff --git a/core/runner/sqlite.go b/core/runner/sqlite.go index 3d3015c..4d1879f 100644 --- a/core/runner/sqlite.go +++ b/core/runner/sqlite.go @@ -33,7 +33,6 @@ type SQLiteManager struct { dataDir string } -// Global manager var sqliteManager *SQLiteManager // InitSQLite initializes the SQLite manager @@ -55,7 +54,7 @@ func CleanupSQLite() { sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() - // Release all active connections + // Release all connections and close pools for id, conn := range sqliteManager.activeConns { if conn.Pool != nil { conn.Pool.Put(conn.Conn) @@ -63,7 +62,6 @@ func CleanupSQLite() { delete(sqliteManager.activeConns, id) } - // Close all pools for name, pool := range sqliteManager.pools { if err := pool.Close(); err != nil { logger.Error("Failed to close database %s: %v", name, err) @@ -94,13 +92,10 @@ func ReleaseActiveConnections(state *luajit.State) { // Iterate through active connections state.PushNil() // Start iteration for state.Next(-2) { - // Stack now has key at -2 and value at -1 if state.IsTable(-1) { state.GetField(-1, "id") if state.IsString(-1) { connID := state.ToString(-1) - - // Release connection from Go side if conn, exists := sqliteManager.activeConns[connID]; exists { if conn.Pool != nil { conn.Pool.Put(conn.Conn) @@ -130,16 +125,15 @@ func getPool(dbName string) (*sqlitex.Pool, error) { return nil, errors.New("invalid database name") } - // Check for existing pool + // Check for existing pool with read lock sqliteManager.mu.RLock() pool, exists := sqliteManager.pools[dbName] sqliteManager.mu.RUnlock() - if exists { return pool, nil } - // Create new pool + // Create new pool with write lock sqliteManager.mu.Lock() defer sqliteManager.mu.Unlock() @@ -148,12 +142,9 @@ func getPool(dbName string) (*sqlitex.Pool, error) { return pool, nil } - // Create database file path + // Create database file path and pool dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db") - - // Create the pool pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) - if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } @@ -179,7 +170,7 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e return nil, nil, err } - // Get a connection using the newer Take API + // Get a connection dbConn, err := pool.Take(context.Background()) if err != nil { return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err) @@ -197,128 +188,20 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e return dbConn, pool, nil } -// detectParamType determines if parameters are positional or named -func detectParamType(params any) (isArray bool) { - if params == nil { - return false - } - - // Check if it's a map[string]any - if paramsMap, ok := params.(map[string]any); ok { - // Check for the empty string key which indicates an array - if array, hasArray := paramsMap[""]; hasArray { - // Verify it's actually an array - if _, isSlice := array.([]any); isSlice { - return true - } - if _, isFloatSlice := array.([]float64); isFloatSlice { - return true - } - } - return false - } - - // If it's already a slice type - if _, ok := params.([]any); ok { - return true - } - if _, ok := params.([]float64); ok { - return true - } - - return false -} - -// prepareParams processes parameters for SQLite queries -func prepareParams(params any) (map[string]any, []any) { - if params == nil { - return nil, nil - } - - // Handle positional parameters (array-like) - if detectParamType(params) { - var positional []any - - // Extract array from special map format - if paramsMap, ok := params.(map[string]any); ok { - if array, hasArray := paramsMap[""]; hasArray { - if slice, ok := array.([]any); ok { - positional = slice - } else if floatSlice, ok := array.([]float64); ok { - // Convert []float64 to []any - positional = make([]any, len(floatSlice)) - for i, v := range floatSlice { - positional[i] = v - } - } - } - } else if slice, ok := params.([]any); ok { - positional = slice - } else if floatSlice, ok := params.([]float64); ok { - // Convert []float64 to []any - positional = make([]any, len(floatSlice)) - for i, v := range floatSlice { - positional[i] = v - } - } - - return nil, positional - } - - // Handle named parameters (map-like) - if paramsMap, ok := params.(map[string]any); ok { - modified := make(map[string]any, len(paramsMap)) - - for key, value := range paramsMap { - if len(key) > 0 && key[0] != ':' { - modified[":"+key] = value - } else { - modified[key] = value - } - } - - return modified, nil - } - - return nil, nil -} - -// luaSQLQuery executes a SQL query and returns results to Lua -func luaSQLQuery(state *luajit.State) int { - // Get database name - if !state.IsString(1) { - state.PushString("sqlite.query: database name must be a string") - return -1 - } - dbName := state.ToString(1) - - // Get query - if !state.IsString(2) { - state.PushString("sqlite.query: query must be a string") - return -1 - } - query := state.ToString(2) +// processParams extracts parameters and connection ID from Lua state +func processParams(state *luajit.State, defaultConnID string) (params any, connID string, isPositional bool, positionalParams []any, err error) { + connID = defaultConnID // Check if using positional parameters - isPositional := false - var positionalParams []any - - // Get connection ID (optional) - var connID string - - // Check if we have positional parameters instead of a params table if state.GetTop() >= 3 && !state.IsTable(3) { isPositional = true paramCount := state.GetTop() - 2 // Count all args after db and query - // Adjust connection ID index if we have positional params - if paramCount > 0 { - // Last parameter might be connID if it's a string - lastIdx := paramCount + 2 // db(1) + query(2) + paramCount - if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { - connID = state.ToString(lastIdx) - paramCount-- // Exclude connID from param count - } + // Check if last param is a connection ID + lastIdx := paramCount + 2 // db(1) + query(2) + paramCount + if paramCount > 0 && state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { + connID = state.ToString(lastIdx) + paramCount-- // Exclude connID from param count } // Create array for positional parameters @@ -327,95 +210,170 @@ func luaSQLQuery(state *luajit.State) int { // Collect all parameters for i := 0; i < paramCount; i++ { paramIdx := i + 3 // Params start at index 3 - - // Convert to appropriate Go value - var value any switch state.GetType(paramIdx) { case luajit.TypeNumber: - value = state.ToNumber(paramIdx) + positionalParams[i] = state.ToNumber(paramIdx) case luajit.TypeString: - value = state.ToString(paramIdx) + positionalParams[i] = state.ToString(paramIdx) case luajit.TypeBoolean: - value = state.ToBoolean(paramIdx) + positionalParams[i] = state.ToBoolean(paramIdx) case luajit.TypeNil: - value = nil + positionalParams[i] = nil default: - // Try to convert as generic value - var err error - value, err = state.ToValue(paramIdx) - if err != nil { - state.PushString(fmt.Sprintf("sqlite.query: failed to convert parameter %d: %s", i+1, err.Error())) - return -1 + val, errConv := state.ToValue(paramIdx) + if errConv != nil { + return nil, "", false, nil, fmt.Errorf("failed to convert parameter %d: %w", i+1, errConv) + } + positionalParams[i] = val + } + } + return nil, connID, isPositional, positionalParams, nil + } + + // Named parameter handling + if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { + connID = state.ToString(4) + } + + // Get table parameters if present + if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { + params, err = state.ToTable(3) + } + + return params, connID, isPositional, nil, err +} + +// prepareExecOptions prepares SQLite execution options based on parameters +func prepareExecOptions(query string, params any, isPositional bool, positionalParams []any) *sqlitex.ExecOptions { + execOpts := &sqlitex.ExecOptions{} + + if params == nil && !isPositional { + return execOpts + } + + // Prepare parameters + isArray := false + var namedParams map[string]any + var arrParams []any + + // Check for array parameters + if m, ok := params.(map[string]any); ok { + if arr, hasArray := m[""]; hasArray { + isArray = true + if slice, ok := arr.([]any); ok { + arrParams = slice + } else if floatSlice, ok := arr.([]float64); ok { + arrParams = make([]any, len(floatSlice)) + for i, v := range floatSlice { + arrParams[i] = v } } - - positionalParams[i] = value - } - } else { - // Original named parameter table handling - if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { - connID = state.ToString(4) } else { - // Generate a temporary connection ID - connID = fmt.Sprintf("temp_%p", &query) + // Process named parameters + namedParams = make(map[string]any, len(m)) + for k, v := range m { + if len(k) > 0 && k[0] != ':' { + namedParams[":"+k] = v + } else { + namedParams[k] = v + } + } + } + } else if slice, ok := params.([]any); ok { + isArray = true + arrParams = slice + } else if floatSlice, ok := params.([]float64); ok { + isArray = true + arrParams = make([]any, len(floatSlice)) + for i, v := range floatSlice { + arrParams[i] = v } } - // Get parameters (optional for named parameters) - var params any - if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { - var err error - params, err = state.ToTable(3) - if err != nil { - state.PushString("sqlite.query: failed to parse parameters: " + err.Error()) - return -1 + // Use positional params if explicitly provided + if isPositional { + arrParams = positionalParams + isArray = true + } + + // Limit positional params to actual placeholders + if isArray && arrParams != nil { + placeholderCount := strings.Count(query, "?") + if len(arrParams) > placeholderCount { + arrParams = arrParams[:placeholderCount] } + execOpts.Args = arrParams + } else if namedParams != nil { + execOpts.Named = namedParams + } + + return execOpts +} + +// sqlOperation handles both query and exec operations +func sqlOperation(state *luajit.State, isQuery bool) int { + operation := "query" + if !isQuery { + operation = "exec" + } + + // Get database name + if !state.IsString(1) { + state.PushString(fmt.Sprintf("sqlite.%s: database name must be a string", operation)) + return -1 + } + dbName := state.ToString(1) + + // Get query + if !state.IsString(2) { + state.PushString(fmt.Sprintf("sqlite.%s: query must be a string", operation)) + return -1 + } + query := state.ToString(2) + + // Generate a temporary connection ID if needed + defaultConnID := fmt.Sprintf("temp_%p", &query) + + // Process parameters and get connection ID + params, connID, isPositional, positionalParams, err := processParams(state, defaultConnID) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) + return -1 } // Get connection conn, pool, err := getConnection(dbName, connID) if err != nil { - state.PushString("sqlite.query: " + err.Error()) + state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) return -1 } // For temporary connections, defer release if strings.HasPrefix(connID, "temp_") { defer func() { - // Release the connection sqliteManager.mu.Lock() delete(sqliteManager.activeConns, connID) sqliteManager.mu.Unlock() - pool.Put(conn) }() } - // Execute query and collect results + // Prepare execution options + execOpts := prepareExecOptions(query, params, isPositional, positionalParams) + + // Define rows slice outside the closure var rows []map[string]any - // Prepare params based on type - namedParams, positional := prepareParams(params) - - // If we have direct positional params from function args, use those - if isPositional { - positional = positionalParams - } - - // Count actual placeholders in the query - placeholderCount := strings.Count(query, "?") - - // Execute with appropriate parameter type - execOpts := &sqlitex.ExecOptions{ - ResultFunc: func(stmt *sqlite.Stmt) error { + // For queries, add result function + if isQuery { + execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { row := make(map[string]any) columnCount := stmt.ColumnCount() for i := range columnCount { columnName := stmt.ColumnName(i) - columnType := stmt.ColumnType(i) - switch columnType { + switch stmt.ColumnType(i) { case sqlite.TypeInteger: row[columnName] = stmt.ColumnInt64(i) case sqlite.TypeFloat: @@ -425,8 +383,7 @@ func luaSQLQuery(state *luajit.State) int { case sqlite.TypeBlob: blobSize := stmt.ColumnLen(i) buf := make([]byte, blobSize) - blob := stmt.ColumnBytes(i, buf) - row[columnName] = blob + row[columnName] = stmt.ColumnBytes(i, buf) case sqlite.TypeNull: row[columnName] = nil } @@ -437,194 +394,51 @@ func luaSQLQuery(state *luajit.State) int { maps.Copy(rowCopy, row) rows = append(rows, rowCopy) return nil - }, - } - - // Set appropriate parameter type - if namedParams != nil { - execOpts.Named = namedParams - } else if positional != nil { - // Make sure we're not passing more positional parameters than placeholders - if len(positional) > placeholderCount { - positional = positional[:placeholderCount] } - execOpts.Args = positional } - err = sqlitex.Execute(conn, query, execOpts) + // Execute query + var execErr error + if isQuery || execOpts.Args != nil || execOpts.Named != nil { + execErr = sqlitex.Execute(conn, query, execOpts) + } else { + // Use ExecScript for queries without parameters + execErr = sqlitex.ExecScript(conn, query) + } - if err != nil { - state.PushString("sqlite.query: " + err.Error()) + if execErr != nil { + state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, execErr.Error())) return -1 } - // Create result table - state.NewTable() - - // Add results to the table - for i, row := range rows { - state.PushNumber(float64(i + 1)) - if err := state.PushTable(row); err != nil { - state.PushString("sqlite.query: " + err.Error()) - return -1 + // Return results for query, affected rows for exec + if isQuery { + // Create result table with rows + state.NewTable() + for i, row := range rows { + state.PushNumber(float64(i + 1)) + if err := state.PushTable(row); err != nil { + state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) + return -1 + } + state.SetTable(-3) } - state.SetTable(-3) + } else { + // Return number of affected rows + state.PushNumber(float64(conn.Changes())) } return 1 } +// luaSQLQuery executes a SQL query and returns results to Lua +func luaSQLQuery(state *luajit.State) int { + return sqlOperation(state, true) +} + // luaSQLExec executes a SQL statement without returning results func luaSQLExec(state *luajit.State) int { - // Get database name and query - if !state.IsString(1) { - state.PushString("sqlite.exec: database name must be a string") - return -1 - } - dbName := state.ToString(1) - - if !state.IsString(2) { - state.PushString("sqlite.exec: query must be a string") - return -1 - } - query := state.ToString(2) - - // Check if using positional parameters - isPositional := false - var positionalParams []any - - // Get connection ID (optional) - var connID string - - // Check if we have positional parameters instead of a params table - if state.GetTop() >= 3 && !state.IsTable(3) { - isPositional = true - paramCount := state.GetTop() - 2 // Count all args after db and query - - // Adjust connection ID index if we have positional params - if paramCount > 0 { - // Last parameter might be connID if it's a string - lastIdx := paramCount + 2 // db(1) + query(2) + paramCount - if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { - connID = state.ToString(lastIdx) - paramCount-- // Exclude connID from param count - } - } - - // Create array for positional parameters - positionalParams = make([]any, paramCount) - - // Collect all parameters - for i := 0; i < paramCount; i++ { - paramIdx := i + 3 // Params start at index 3 - - // Convert to appropriate Go value - var value any - switch state.GetType(paramIdx) { - case luajit.TypeNumber: - value = state.ToNumber(paramIdx) - case luajit.TypeString: - value = state.ToString(paramIdx) - case luajit.TypeBoolean: - value = state.ToBoolean(paramIdx) - case luajit.TypeNil: - value = nil - default: - // Try to convert as generic value - var err error - value, err = state.ToValue(paramIdx) - if err != nil { - state.PushString(fmt.Sprintf("sqlite.exec: failed to convert parameter %d: %s", i+1, err.Error())) - return -1 - } - } - - positionalParams[i] = value - } - } else { - // Original named parameter table handling - if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { - connID = state.ToString(4) - } else { - // Generate a temporary connection ID - connID = fmt.Sprintf("temp_%p", &query) - } - } - - // Get parameters (optional for named parameters) - var params any - if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { - var err error - params, err = state.ToTable(3) - if err != nil { - state.PushString("sqlite.exec: failed to parse parameters: " + err.Error()) - return -1 - } - } - - // Get connection - conn, pool, err := getConnection(dbName, connID) - if err != nil { - state.PushString("sqlite.exec: " + err.Error()) - return -1 - } - - // For temporary connections, defer release - if strings.HasPrefix(connID, "temp_") { - defer func() { - // Release the connection - sqliteManager.mu.Lock() - delete(sqliteManager.activeConns, connID) - sqliteManager.mu.Unlock() - - pool.Put(conn) - }() - } - - // Count actual placeholders in the query - placeholderCount := strings.Count(query, "?") - - // Prepare params based on type - namedParams, positional := prepareParams(params) - - // If we have direct positional params from function args, use those - if isPositional { - positional = positionalParams - } - - // Ensure we don't pass more parameters than placeholders - if positional != nil && len(positional) > placeholderCount { - positional = positional[:placeholderCount] - } - - // Execute with appropriate parameter type - var execErr error - - if isPositional || positional != nil { - // Execute with positional parameters - execOpts := &sqlitex.ExecOptions{ - Args: positional, - } - execErr = sqlitex.Execute(conn, query, execOpts) - } else if namedParams != nil { - // Execute with named parameters - execOpts := &sqlitex.ExecOptions{ - Named: namedParams, - } - execErr = sqlitex.Execute(conn, query, execOpts) - } else { - // Execute without parameters - execErr = sqlitex.ExecScript(conn, query) - } - - if execErr != nil { - state.PushString("sqlite.exec: " + execErr.Error()) - return -1 - } - - // Return number of affected rows - state.PushNumber(float64(conn.Changes())) - return 1 + return sqlOperation(state, false) } // RegisterSQLiteFunctions registers SQLite functions with the Lua state @@ -632,10 +446,5 @@ func RegisterSQLiteFunctions(state *luajit.State) error { if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil { return err } - - if err := state.RegisterGoFunction("__sqlite_exec", luaSQLExec); err != nil { - return err - } - - return nil + return state.RegisterGoFunction("__sqlite_exec", luaSQLExec) }