add positional parameter support

This commit is contained in:
Sky Johnson 2025-05-03 16:10:40 -05:00
parent c005066816
commit 551f311755
3 changed files with 357 additions and 88 deletions

View File

@ -29,7 +29,7 @@ func passwordHash(state *luajit.State) int {
params := &argon2id.Params{ params := &argon2id.Params{
Memory: 64 * 1024, Memory: 64 * 1024,
Iterations: 3, Iterations: 4,
Parallelism: 4, Parallelism: 4,
SaltLength: 16, SaltLength: 16,
KeyLength: 32, KeyLength: 32,
@ -38,46 +38,31 @@ func passwordHash(state *luajit.State) int {
if state.IsTable(2) { if state.IsTable(2) {
state.GetField(2, "memory") state.GetField(2, "memory")
if state.IsNumber(-1) { if state.IsNumber(-1) {
params.Memory = uint32(state.ToNumber(-1)) params.Memory = max(uint32(state.ToNumber(-1)), 8*1024)
if params.Memory < 8*1024 {
params.Memory = 8 * 1024 // Minimum 8MB
}
} }
state.Pop(1) state.Pop(1)
state.GetField(2, "iterations") state.GetField(2, "iterations")
if state.IsNumber(-1) { if state.IsNumber(-1) {
params.Iterations = uint32(state.ToNumber(-1)) params.Iterations = max(uint32(state.ToNumber(-1)), 1)
if params.Iterations < 1 {
params.Iterations = 1 // Minimum 1 iteration
}
} }
state.Pop(1) state.Pop(1)
state.GetField(2, "parallelism") state.GetField(2, "parallelism")
if state.IsNumber(-1) { if state.IsNumber(-1) {
params.Parallelism = uint8(state.ToNumber(-1)) params.Parallelism = max(uint8(state.ToNumber(-1)), 1)
if params.Parallelism < 1 {
params.Parallelism = 1 // Minimum 1 thread
}
} }
state.Pop(1) state.Pop(1)
state.GetField(2, "salt_length") state.GetField(2, "salt_length")
if state.IsNumber(-1) { if state.IsNumber(-1) {
params.SaltLength = uint32(state.ToNumber(-1)) params.SaltLength = max(uint32(state.ToNumber(-1)), 8)
if params.SaltLength < 8 {
params.SaltLength = 8 // Minimum 8 bytes
}
} }
state.Pop(1) state.Pop(1)
state.GetField(2, "key_length") state.GetField(2, "key_length")
if state.IsNumber(-1) { if state.IsNumber(-1) {
params.KeyLength = uint32(state.ToNumber(-1)) params.KeyLength = max(uint32(state.ToNumber(-1)), 16)
if params.KeyLength < 16 {
params.KeyLength = 16 // Minimum 16 bytes
}
} }
state.Pop(1) state.Pop(1)
} }

View File

@ -1,6 +1,7 @@
package runner package runner
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"path/filepath" "path/filepath"
@ -178,10 +179,10 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
return nil, nil, err return nil, nil, err
} }
// Get a connection // Get a connection using the newer Take API
dbConn := pool.Get(nil) dbConn, err := pool.Take(context.Background())
if dbConn == nil { if err != nil {
return nil, nil, errors.New("failed to get connection from pool") return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
} }
// Store connection // Store connection
@ -196,6 +197,92 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
return dbConn, pool, nil return dbConn, pool, nil
} }
// detectParamType determines if parameters are positional or named
func detectParamType(params any) (isArray bool) {
if params == nil {
return false
}
// Check if it's a map[string]any
if paramsMap, ok := params.(map[string]any); ok {
// Check for the empty string key which indicates an array
if array, hasArray := paramsMap[""]; hasArray {
// Verify it's actually an array
if _, isSlice := array.([]any); isSlice {
return true
}
if _, isFloatSlice := array.([]float64); isFloatSlice {
return true
}
}
return false
}
// If it's already a slice type
if _, ok := params.([]any); ok {
return true
}
if _, ok := params.([]float64); ok {
return true
}
return false
}
// prepareParams processes parameters for SQLite queries
func prepareParams(params any) (map[string]any, []any) {
if params == nil {
return nil, nil
}
// Handle positional parameters (array-like)
if detectParamType(params) {
var positional []any
// Extract array from special map format
if paramsMap, ok := params.(map[string]any); ok {
if array, hasArray := paramsMap[""]; hasArray {
if slice, ok := array.([]any); ok {
positional = slice
} else if floatSlice, ok := array.([]float64); ok {
// Convert []float64 to []any
positional = make([]any, len(floatSlice))
for i, v := range floatSlice {
positional[i] = v
}
}
}
} else if slice, ok := params.([]any); ok {
positional = slice
} else if floatSlice, ok := params.([]float64); ok {
// Convert []float64 to []any
positional = make([]any, len(floatSlice))
for i, v := range floatSlice {
positional[i] = v
}
}
return nil, positional
}
// Handle named parameters (map-like)
if paramsMap, ok := params.(map[string]any); ok {
modified := make(map[string]any, len(paramsMap))
for key, value := range paramsMap {
if len(key) > 0 && key[0] != ':' {
modified[":"+key] = value
} else {
modified[key] = value
}
}
return modified, nil
}
return nil, nil
}
// luaSQLQuery executes a SQL query and returns results to Lua // luaSQLQuery executes a SQL query and returns results to Lua
func luaSQLQuery(state *luajit.State) int { func luaSQLQuery(state *luajit.State) int {
// Get database name // Get database name
@ -212,18 +299,71 @@ func luaSQLQuery(state *luajit.State) int {
} }
query := state.ToString(2) query := state.ToString(2)
// Get connection ID (optional for compatibility) // Check if using positional parameters
isPositional := false
var positionalParams []any
// Get connection ID (optional)
var connID string var connID string
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4) // Check if we have positional parameters instead of a params table
if state.GetTop() >= 3 && !state.IsTable(3) {
isPositional = true
paramCount := state.GetTop() - 2 // Count all args after db and query
// Adjust connection ID index if we have positional params
if paramCount > 0 {
// Last parameter might be connID if it's a string
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
connID = state.ToString(lastIdx)
paramCount-- // Exclude connID from param count
}
}
// Create array for positional parameters
positionalParams = make([]any, paramCount)
// Collect all parameters
for i := 0; i < paramCount; i++ {
paramIdx := i + 3 // Params start at index 3
// Convert to appropriate Go value
var value any
switch state.GetType(paramIdx) {
case luajit.TypeNumber:
value = state.ToNumber(paramIdx)
case luajit.TypeString:
value = state.ToString(paramIdx)
case luajit.TypeBoolean:
value = state.ToBoolean(paramIdx)
case luajit.TypeNil:
value = nil
default:
// Try to convert as generic value
var err error
value, err = state.ToValue(paramIdx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: failed to convert parameter %d: %s", i+1, err.Error()))
return -1
}
}
positionalParams[i] = value
}
} else { } else {
// Generate a temporary connection ID // Original named parameter table handling
connID = fmt.Sprintf("temp_%p", &query) if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4)
} else {
// Generate a temporary connection ID
connID = fmt.Sprintf("temp_%p", &query)
}
} }
// Get parameters (optional) // Get parameters (optional for named parameters)
var params map[string]any var params any
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
var err error var err error
params, err = state.ToTable(3) params, err = state.ToTable(3)
if err != nil { if err != nil {
@ -240,20 +380,33 @@ func luaSQLQuery(state *luajit.State) int {
} }
// For temporary connections, defer release // For temporary connections, defer release
if !strings.HasPrefix(connID, "temp_") { if strings.HasPrefix(connID, "temp_") {
defer pool.Put(conn) defer func() {
// Release the connection
sqliteManager.mu.Lock()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
// Remove from active connections pool.Put(conn)
sqliteManager.mu.Lock() }()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
} }
// Execute query and collect results // Execute query and collect results
var rows []map[string]any var rows []map[string]any
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{ // Prepare params based on type
Named: prepareNamedParams(params), namedParams, positional := prepareParams(params)
// If we have direct positional params from function args, use those
if isPositional {
positional = positionalParams
}
// Count actual placeholders in the query
placeholderCount := strings.Count(query, "?")
// Execute with appropriate parameter type
execOpts := &sqlitex.ExecOptions{
ResultFunc: func(stmt *sqlite.Stmt) error { ResultFunc: func(stmt *sqlite.Stmt) error {
row := make(map[string]any) row := make(map[string]any)
columnCount := stmt.ColumnCount() columnCount := stmt.ColumnCount()
@ -285,7 +438,20 @@ func luaSQLQuery(state *luajit.State) int {
rows = append(rows, rowCopy) rows = append(rows, rowCopy)
return nil return nil
}, },
}) }
// Set appropriate parameter type
if namedParams != nil {
execOpts.Named = namedParams
} else if positional != nil {
// Make sure we're not passing more positional parameters than placeholders
if len(positional) > placeholderCount {
positional = positional[:placeholderCount]
}
execOpts.Args = positional
}
err = sqlitex.Execute(conn, query, execOpts)
if err != nil { if err != nil {
state.PushString("sqlite.query: " + err.Error()) state.PushString("sqlite.query: " + err.Error())
@ -323,18 +489,71 @@ func luaSQLExec(state *luajit.State) int {
} }
query := state.ToString(2) query := state.ToString(2)
// Get connection ID (optional for compatibility) // Check if using positional parameters
isPositional := false
var positionalParams []any
// Get connection ID (optional)
var connID string var connID string
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4) // Check if we have positional parameters instead of a params table
if state.GetTop() >= 3 && !state.IsTable(3) {
isPositional = true
paramCount := state.GetTop() - 2 // Count all args after db and query
// Adjust connection ID index if we have positional params
if paramCount > 0 {
// Last parameter might be connID if it's a string
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
connID = state.ToString(lastIdx)
paramCount-- // Exclude connID from param count
}
}
// Create array for positional parameters
positionalParams = make([]any, paramCount)
// Collect all parameters
for i := 0; i < paramCount; i++ {
paramIdx := i + 3 // Params start at index 3
// Convert to appropriate Go value
var value any
switch state.GetType(paramIdx) {
case luajit.TypeNumber:
value = state.ToNumber(paramIdx)
case luajit.TypeString:
value = state.ToString(paramIdx)
case luajit.TypeBoolean:
value = state.ToBoolean(paramIdx)
case luajit.TypeNil:
value = nil
default:
// Try to convert as generic value
var err error
value, err = state.ToValue(paramIdx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: failed to convert parameter %d: %s", i+1, err.Error()))
return -1
}
}
positionalParams[i] = value
}
} else { } else {
// Generate a temporary connection ID // Original named parameter table handling
connID = fmt.Sprintf("temp_%p", &query) if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4)
} else {
// Generate a temporary connection ID
connID = fmt.Sprintf("temp_%p", &query)
}
} }
// Get parameters (optional) // Get parameters (optional for named parameters)
var params map[string]any var params any
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
var err error var err error
params, err = state.ToTable(3) params, err = state.ToTable(3)
if err != nil { if err != nil {
@ -351,26 +570,55 @@ func luaSQLExec(state *luajit.State) int {
} }
// For temporary connections, defer release // For temporary connections, defer release
if !strings.HasPrefix(connID, "temp_") { if strings.HasPrefix(connID, "temp_") {
defer pool.Put(conn) defer func() {
// Release the connection
sqliteManager.mu.Lock()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
// Remove from active connections pool.Put(conn)
sqliteManager.mu.Lock() }()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
} }
// Execute statement // Count actual placeholders in the query
if params != nil { placeholderCount := strings.Count(query, "?")
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
Named: prepareNamedParams(params), // Prepare params based on type
}) namedParams, positional := prepareParams(params)
// If we have direct positional params from function args, use those
if isPositional {
positional = positionalParams
}
// Ensure we don't pass more parameters than placeholders
if positional != nil && len(positional) > placeholderCount {
positional = positional[:placeholderCount]
}
// Execute with appropriate parameter type
var execErr error
if isPositional || positional != nil {
// Execute with positional parameters
execOpts := &sqlitex.ExecOptions{
Args: positional,
}
execErr = sqlitex.Execute(conn, query, execOpts)
} else if namedParams != nil {
// Execute with named parameters
execOpts := &sqlitex.ExecOptions{
Named: namedParams,
}
execErr = sqlitex.Execute(conn, query, execOpts)
} else { } else {
err = sqlitex.ExecScript(conn, query) // Execute without parameters
execErr = sqlitex.ExecScript(conn, query)
} }
if err != nil { if execErr != nil {
state.PushString("sqlite.exec: " + err.Error()) state.PushString("sqlite.exec: " + execErr.Error())
return -1 return -1
} }
@ -391,21 +639,3 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
return nil return nil
} }
func prepareNamedParams(params map[string]any) map[string]any {
if params == nil {
return nil
}
modified := make(map[string]any, len(params))
for key, value := range params {
if len(key) > 0 && key[0] != ':' {
modified[":"+key] = value
} else {
modified[key] = value
}
}
return modified
}

View File

@ -1,22 +1,70 @@
__active_sqlite_connections = {} __active_sqlite_connections = {}
-- Helper function to handle parameters
local function handle_params(params, ...)
-- If params is a table, use it for named parameters
if type(params) == "table" then
return params
end
-- If we have varargs, collect them for positional parameters
local args = {...}
if #args > 0 or params ~= nil then
-- Include the first param in the args
table.insert(args, 1, params)
return args
end
return nil
end
-- Connection metatable -- Connection metatable
local connection_mt = { local connection_mt = {
__index = { __index = {
-- Execute a query and return results as a table -- Execute a query and return results as a table
query = function(self, query, params) query = function(self, query, params, ...)
if type(query) ~= "string" then if type(query) ~= "string" then
error("connection:query: query must be a string", 2) error("connection:query: query must be a string", 2)
end end
return __sqlite_query(self.db_name, query, params)
-- Handle params (named or positional)
local processed_params = handle_params(params, ...)
-- Call with appropriate arguments
if type(processed_params) == "table" and processed_params[1] ~= nil then
-- Positional parameters - insert self.db_name and query at the beginning
table.insert(processed_params, 1, query)
table.insert(processed_params, 1, self.db_name)
-- Add connection ID at the end
table.insert(processed_params, self.id)
return __sqlite_query(unpack(processed_params))
else
-- Named parameters or no parameters
return __sqlite_query(self.db_name, query, processed_params, self.id)
end
end, end,
-- Execute a statement and return affected rows -- Execute a statement and return affected rows
exec = function(self, query, params) exec = function(self, query, params, ...)
if type(query) ~= "string" then if type(query) ~= "string" then
error("connection:exec: query must be a string", 2) error("connection:exec: query must be a string", 2)
end end
return __sqlite_exec(self.db_name, query, params)
-- Handle params (named or positional)
local processed_params = handle_params(params, ...)
-- Call with appropriate arguments
if type(processed_params) == "table" and processed_params[1] ~= nil then
-- Positional parameters - insert self.db_name and query at the beginning
table.insert(processed_params, 1, query)
table.insert(processed_params, 1, self.db_name)
-- Add connection ID at the end
table.insert(processed_params, self.id)
return __sqlite_exec(unpack(processed_params))
else
-- Named parameters or no parameters
return __sqlite_exec(self.db_name, query, processed_params, self.id)
end
end, end,
-- Create a new table -- Create a new table
@ -121,8 +169,14 @@ local connection_mt = {
end, end,
-- Get one row -- Get one row
get_one = function(self, query, params) get_one = function(self, query, params, ...)
local results = self:query(query, params) -- Handle both named and positional parameters
local results
if select('#', ...) > 0 then
results = self:query(query, params, ...)
else
results = self:query(query, params)
end
return results[1] return results[1]
end, end,