diff --git a/core/runner/password.go b/core/runner/password.go index 754d32a..30c391a 100644 --- a/core/runner/password.go +++ b/core/runner/password.go @@ -29,7 +29,7 @@ func passwordHash(state *luajit.State) int { params := &argon2id.Params{ Memory: 64 * 1024, - Iterations: 3, + Iterations: 4, Parallelism: 4, SaltLength: 16, KeyLength: 32, @@ -38,46 +38,31 @@ func passwordHash(state *luajit.State) int { if state.IsTable(2) { state.GetField(2, "memory") if state.IsNumber(-1) { - params.Memory = uint32(state.ToNumber(-1)) - if params.Memory < 8*1024 { - params.Memory = 8 * 1024 // Minimum 8MB - } + params.Memory = max(uint32(state.ToNumber(-1)), 8*1024) } state.Pop(1) state.GetField(2, "iterations") if state.IsNumber(-1) { - params.Iterations = uint32(state.ToNumber(-1)) - if params.Iterations < 1 { - params.Iterations = 1 // Minimum 1 iteration - } + params.Iterations = max(uint32(state.ToNumber(-1)), 1) } state.Pop(1) state.GetField(2, "parallelism") if state.IsNumber(-1) { - params.Parallelism = uint8(state.ToNumber(-1)) - if params.Parallelism < 1 { - params.Parallelism = 1 // Minimum 1 thread - } + params.Parallelism = max(uint8(state.ToNumber(-1)), 1) } state.Pop(1) state.GetField(2, "salt_length") if state.IsNumber(-1) { - params.SaltLength = uint32(state.ToNumber(-1)) - if params.SaltLength < 8 { - params.SaltLength = 8 // Minimum 8 bytes - } + params.SaltLength = max(uint32(state.ToNumber(-1)), 8) } state.Pop(1) state.GetField(2, "key_length") if state.IsNumber(-1) { - params.KeyLength = uint32(state.ToNumber(-1)) - if params.KeyLength < 16 { - params.KeyLength = 16 // Minimum 16 bytes - } + params.KeyLength = max(uint32(state.ToNumber(-1)), 16) } state.Pop(1) } diff --git a/core/runner/sqlite.go b/core/runner/sqlite.go index 17fa80c..3d3015c 100644 --- a/core/runner/sqlite.go +++ b/core/runner/sqlite.go @@ -1,6 +1,7 @@ package runner import ( + "context" "errors" "fmt" "path/filepath" @@ -178,10 +179,10 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e return nil, nil, err } - // Get a connection - dbConn := pool.Get(nil) - if dbConn == nil { - return nil, nil, errors.New("failed to get connection from pool") + // Get a connection using the newer Take API + dbConn, err := pool.Take(context.Background()) + if err != nil { + return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err) } // Store connection @@ -196,6 +197,92 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e return dbConn, pool, nil } +// detectParamType determines if parameters are positional or named +func detectParamType(params any) (isArray bool) { + if params == nil { + return false + } + + // Check if it's a map[string]any + if paramsMap, ok := params.(map[string]any); ok { + // Check for the empty string key which indicates an array + if array, hasArray := paramsMap[""]; hasArray { + // Verify it's actually an array + if _, isSlice := array.([]any); isSlice { + return true + } + if _, isFloatSlice := array.([]float64); isFloatSlice { + return true + } + } + return false + } + + // If it's already a slice type + if _, ok := params.([]any); ok { + return true + } + if _, ok := params.([]float64); ok { + return true + } + + return false +} + +// prepareParams processes parameters for SQLite queries +func prepareParams(params any) (map[string]any, []any) { + if params == nil { + return nil, nil + } + + // Handle positional parameters (array-like) + if detectParamType(params) { + var positional []any + + // Extract array from special map format + if paramsMap, ok := params.(map[string]any); ok { + if array, hasArray := paramsMap[""]; hasArray { + if slice, ok := array.([]any); ok { + positional = slice + } else if floatSlice, ok := array.([]float64); ok { + // Convert []float64 to []any + positional = make([]any, len(floatSlice)) + for i, v := range floatSlice { + positional[i] = v + } + } + } + } else if slice, ok := params.([]any); ok { + positional = slice + } else if floatSlice, ok := params.([]float64); ok { + // Convert []float64 to []any + positional = make([]any, len(floatSlice)) + for i, v := range floatSlice { + positional[i] = v + } + } + + return nil, positional + } + + // Handle named parameters (map-like) + if paramsMap, ok := params.(map[string]any); ok { + modified := make(map[string]any, len(paramsMap)) + + for key, value := range paramsMap { + if len(key) > 0 && key[0] != ':' { + modified[":"+key] = value + } else { + modified[key] = value + } + } + + return modified, nil + } + + return nil, nil +} + // luaSQLQuery executes a SQL query and returns results to Lua func luaSQLQuery(state *luajit.State) int { // Get database name @@ -212,18 +299,71 @@ func luaSQLQuery(state *luajit.State) int { } query := state.ToString(2) - // Get connection ID (optional for compatibility) + // Check if using positional parameters + isPositional := false + var positionalParams []any + + // Get connection ID (optional) var connID string - if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { - connID = state.ToString(4) + + // Check if we have positional parameters instead of a params table + if state.GetTop() >= 3 && !state.IsTable(3) { + isPositional = true + paramCount := state.GetTop() - 2 // Count all args after db and query + + // Adjust connection ID index if we have positional params + if paramCount > 0 { + // Last parameter might be connID if it's a string + lastIdx := paramCount + 2 // db(1) + query(2) + paramCount + if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { + connID = state.ToString(lastIdx) + paramCount-- // Exclude connID from param count + } + } + + // Create array for positional parameters + positionalParams = make([]any, paramCount) + + // Collect all parameters + for i := 0; i < paramCount; i++ { + paramIdx := i + 3 // Params start at index 3 + + // Convert to appropriate Go value + var value any + switch state.GetType(paramIdx) { + case luajit.TypeNumber: + value = state.ToNumber(paramIdx) + case luajit.TypeString: + value = state.ToString(paramIdx) + case luajit.TypeBoolean: + value = state.ToBoolean(paramIdx) + case luajit.TypeNil: + value = nil + default: + // Try to convert as generic value + var err error + value, err = state.ToValue(paramIdx) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.query: failed to convert parameter %d: %s", i+1, err.Error())) + return -1 + } + } + + positionalParams[i] = value + } } else { - // Generate a temporary connection ID - connID = fmt.Sprintf("temp_%p", &query) + // Original named parameter table handling + if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { + connID = state.ToString(4) + } else { + // Generate a temporary connection ID + connID = fmt.Sprintf("temp_%p", &query) + } } - // Get parameters (optional) - var params map[string]any - if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { + // Get parameters (optional for named parameters) + var params any + if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { var err error params, err = state.ToTable(3) if err != nil { @@ -240,20 +380,33 @@ func luaSQLQuery(state *luajit.State) int { } // For temporary connections, defer release - if !strings.HasPrefix(connID, "temp_") { - defer pool.Put(conn) + if strings.HasPrefix(connID, "temp_") { + defer func() { + // Release the connection + sqliteManager.mu.Lock() + delete(sqliteManager.activeConns, connID) + sqliteManager.mu.Unlock() - // Remove from active connections - sqliteManager.mu.Lock() - delete(sqliteManager.activeConns, connID) - sqliteManager.mu.Unlock() + pool.Put(conn) + }() } // Execute query and collect results var rows []map[string]any - err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ - Named: prepareNamedParams(params), + // Prepare params based on type + namedParams, positional := prepareParams(params) + + // If we have direct positional params from function args, use those + if isPositional { + positional = positionalParams + } + + // Count actual placeholders in the query + placeholderCount := strings.Count(query, "?") + + // Execute with appropriate parameter type + execOpts := &sqlitex.ExecOptions{ ResultFunc: func(stmt *sqlite.Stmt) error { row := make(map[string]any) columnCount := stmt.ColumnCount() @@ -285,7 +438,20 @@ func luaSQLQuery(state *luajit.State) int { rows = append(rows, rowCopy) return nil }, - }) + } + + // Set appropriate parameter type + if namedParams != nil { + execOpts.Named = namedParams + } else if positional != nil { + // Make sure we're not passing more positional parameters than placeholders + if len(positional) > placeholderCount { + positional = positional[:placeholderCount] + } + execOpts.Args = positional + } + + err = sqlitex.Execute(conn, query, execOpts) if err != nil { state.PushString("sqlite.query: " + err.Error()) @@ -323,18 +489,71 @@ func luaSQLExec(state *luajit.State) int { } query := state.ToString(2) - // Get connection ID (optional for compatibility) + // Check if using positional parameters + isPositional := false + var positionalParams []any + + // Get connection ID (optional) var connID string - if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { - connID = state.ToString(4) + + // Check if we have positional parameters instead of a params table + if state.GetTop() >= 3 && !state.IsTable(3) { + isPositional = true + paramCount := state.GetTop() - 2 // Count all args after db and query + + // Adjust connection ID index if we have positional params + if paramCount > 0 { + // Last parameter might be connID if it's a string + lastIdx := paramCount + 2 // db(1) + query(2) + paramCount + if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { + connID = state.ToString(lastIdx) + paramCount-- // Exclude connID from param count + } + } + + // Create array for positional parameters + positionalParams = make([]any, paramCount) + + // Collect all parameters + for i := 0; i < paramCount; i++ { + paramIdx := i + 3 // Params start at index 3 + + // Convert to appropriate Go value + var value any + switch state.GetType(paramIdx) { + case luajit.TypeNumber: + value = state.ToNumber(paramIdx) + case luajit.TypeString: + value = state.ToString(paramIdx) + case luajit.TypeBoolean: + value = state.ToBoolean(paramIdx) + case luajit.TypeNil: + value = nil + default: + // Try to convert as generic value + var err error + value, err = state.ToValue(paramIdx) + if err != nil { + state.PushString(fmt.Sprintf("sqlite.exec: failed to convert parameter %d: %s", i+1, err.Error())) + return -1 + } + } + + positionalParams[i] = value + } } else { - // Generate a temporary connection ID - connID = fmt.Sprintf("temp_%p", &query) + // Original named parameter table handling + if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { + connID = state.ToString(4) + } else { + // Generate a temporary connection ID + connID = fmt.Sprintf("temp_%p", &query) + } } - // Get parameters (optional) - var params map[string]any - if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { + // Get parameters (optional for named parameters) + var params any + if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { var err error params, err = state.ToTable(3) if err != nil { @@ -351,26 +570,55 @@ func luaSQLExec(state *luajit.State) int { } // For temporary connections, defer release - if !strings.HasPrefix(connID, "temp_") { - defer pool.Put(conn) + if strings.HasPrefix(connID, "temp_") { + defer func() { + // Release the connection + sqliteManager.mu.Lock() + delete(sqliteManager.activeConns, connID) + sqliteManager.mu.Unlock() - // Remove from active connections - sqliteManager.mu.Lock() - delete(sqliteManager.activeConns, connID) - sqliteManager.mu.Unlock() + pool.Put(conn) + }() } - // Execute statement - if params != nil { - err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ - Named: prepareNamedParams(params), - }) + // Count actual placeholders in the query + placeholderCount := strings.Count(query, "?") + + // Prepare params based on type + namedParams, positional := prepareParams(params) + + // If we have direct positional params from function args, use those + if isPositional { + positional = positionalParams + } + + // Ensure we don't pass more parameters than placeholders + if positional != nil && len(positional) > placeholderCount { + positional = positional[:placeholderCount] + } + + // Execute with appropriate parameter type + var execErr error + + if isPositional || positional != nil { + // Execute with positional parameters + execOpts := &sqlitex.ExecOptions{ + Args: positional, + } + execErr = sqlitex.Execute(conn, query, execOpts) + } else if namedParams != nil { + // Execute with named parameters + execOpts := &sqlitex.ExecOptions{ + Named: namedParams, + } + execErr = sqlitex.Execute(conn, query, execOpts) } else { - err = sqlitex.ExecScript(conn, query) + // Execute without parameters + execErr = sqlitex.ExecScript(conn, query) } - if err != nil { - state.PushString("sqlite.exec: " + err.Error()) + if execErr != nil { + state.PushString("sqlite.exec: " + execErr.Error()) return -1 } @@ -391,21 +639,3 @@ func RegisterSQLiteFunctions(state *luajit.State) error { return nil } - -func prepareNamedParams(params map[string]any) map[string]any { - if params == nil { - return nil - } - - modified := make(map[string]any, len(params)) - - for key, value := range params { - if len(key) > 0 && key[0] != ':' { - modified[":"+key] = value - } else { - modified[key] = value - } - } - - return modified -} diff --git a/core/runner/sqlite.lua b/core/runner/sqlite.lua index 370bd67..002abeb 100644 --- a/core/runner/sqlite.lua +++ b/core/runner/sqlite.lua @@ -1,22 +1,70 @@ __active_sqlite_connections = {} +-- Helper function to handle parameters +local function handle_params(params, ...) + -- If params is a table, use it for named parameters + if type(params) == "table" then + return params + end + + -- If we have varargs, collect them for positional parameters + local args = {...} + if #args > 0 or params ~= nil then + -- Include the first param in the args + table.insert(args, 1, params) + return args + end + + return nil +end + -- Connection metatable local connection_mt = { __index = { -- Execute a query and return results as a table - query = function(self, query, params) + query = function(self, query, params, ...) if type(query) ~= "string" then error("connection:query: query must be a string", 2) end - return __sqlite_query(self.db_name, query, params) + + -- Handle params (named or positional) + local processed_params = handle_params(params, ...) + + -- Call with appropriate arguments + if type(processed_params) == "table" and processed_params[1] ~= nil then + -- Positional parameters - insert self.db_name and query at the beginning + table.insert(processed_params, 1, query) + table.insert(processed_params, 1, self.db_name) + -- Add connection ID at the end + table.insert(processed_params, self.id) + return __sqlite_query(unpack(processed_params)) + else + -- Named parameters or no parameters + return __sqlite_query(self.db_name, query, processed_params, self.id) + end end, -- Execute a statement and return affected rows - exec = function(self, query, params) + exec = function(self, query, params, ...) if type(query) ~= "string" then error("connection:exec: query must be a string", 2) end - return __sqlite_exec(self.db_name, query, params) + + -- Handle params (named or positional) + local processed_params = handle_params(params, ...) + + -- Call with appropriate arguments + if type(processed_params) == "table" and processed_params[1] ~= nil then + -- Positional parameters - insert self.db_name and query at the beginning + table.insert(processed_params, 1, query) + table.insert(processed_params, 1, self.db_name) + -- Add connection ID at the end + table.insert(processed_params, self.id) + return __sqlite_exec(unpack(processed_params)) + else + -- Named parameters or no parameters + return __sqlite_exec(self.db_name, query, processed_params, self.id) + end end, -- Create a new table @@ -121,8 +169,14 @@ local connection_mt = { end, -- Get one row - get_one = function(self, query, params) - local results = self:query(query, params) + get_one = function(self, query, params, ...) + -- Handle both named and positional parameters + local results + if select('#', ...) > 0 then + results = self:query(query, params, ...) + else + results = self:query(query, params) + end return results[1] end,