add positional parameter support

This commit is contained in:
Sky Johnson 2025-05-03 16:10:40 -05:00
parent c005066816
commit 551f311755
3 changed files with 357 additions and 88 deletions

View File

@ -29,7 +29,7 @@ func passwordHash(state *luajit.State) int {
params := &argon2id.Params{
Memory: 64 * 1024,
Iterations: 3,
Iterations: 4,
Parallelism: 4,
SaltLength: 16,
KeyLength: 32,
@ -38,46 +38,31 @@ func passwordHash(state *luajit.State) int {
if state.IsTable(2) {
state.GetField(2, "memory")
if state.IsNumber(-1) {
params.Memory = uint32(state.ToNumber(-1))
if params.Memory < 8*1024 {
params.Memory = 8 * 1024 // Minimum 8MB
}
params.Memory = max(uint32(state.ToNumber(-1)), 8*1024)
}
state.Pop(1)
state.GetField(2, "iterations")
if state.IsNumber(-1) {
params.Iterations = uint32(state.ToNumber(-1))
if params.Iterations < 1 {
params.Iterations = 1 // Minimum 1 iteration
}
params.Iterations = max(uint32(state.ToNumber(-1)), 1)
}
state.Pop(1)
state.GetField(2, "parallelism")
if state.IsNumber(-1) {
params.Parallelism = uint8(state.ToNumber(-1))
if params.Parallelism < 1 {
params.Parallelism = 1 // Minimum 1 thread
}
params.Parallelism = max(uint8(state.ToNumber(-1)), 1)
}
state.Pop(1)
state.GetField(2, "salt_length")
if state.IsNumber(-1) {
params.SaltLength = uint32(state.ToNumber(-1))
if params.SaltLength < 8 {
params.SaltLength = 8 // Minimum 8 bytes
}
params.SaltLength = max(uint32(state.ToNumber(-1)), 8)
}
state.Pop(1)
state.GetField(2, "key_length")
if state.IsNumber(-1) {
params.KeyLength = uint32(state.ToNumber(-1))
if params.KeyLength < 16 {
params.KeyLength = 16 // Minimum 16 bytes
}
params.KeyLength = max(uint32(state.ToNumber(-1)), 16)
}
state.Pop(1)
}

View File

@ -1,6 +1,7 @@
package runner
import (
"context"
"errors"
"fmt"
"path/filepath"
@ -178,10 +179,10 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
return nil, nil, err
}
// Get a connection
dbConn := pool.Get(nil)
if dbConn == nil {
return nil, nil, errors.New("failed to get connection from pool")
// Get a connection using the newer Take API
dbConn, err := pool.Take(context.Background())
if err != nil {
return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
}
// Store connection
@ -196,6 +197,92 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
return dbConn, pool, nil
}
// detectParamType determines if parameters are positional or named
func detectParamType(params any) (isArray bool) {
if params == nil {
return false
}
// Check if it's a map[string]any
if paramsMap, ok := params.(map[string]any); ok {
// Check for the empty string key which indicates an array
if array, hasArray := paramsMap[""]; hasArray {
// Verify it's actually an array
if _, isSlice := array.([]any); isSlice {
return true
}
if _, isFloatSlice := array.([]float64); isFloatSlice {
return true
}
}
return false
}
// If it's already a slice type
if _, ok := params.([]any); ok {
return true
}
if _, ok := params.([]float64); ok {
return true
}
return false
}
// prepareParams processes parameters for SQLite queries
func prepareParams(params any) (map[string]any, []any) {
if params == nil {
return nil, nil
}
// Handle positional parameters (array-like)
if detectParamType(params) {
var positional []any
// Extract array from special map format
if paramsMap, ok := params.(map[string]any); ok {
if array, hasArray := paramsMap[""]; hasArray {
if slice, ok := array.([]any); ok {
positional = slice
} else if floatSlice, ok := array.([]float64); ok {
// Convert []float64 to []any
positional = make([]any, len(floatSlice))
for i, v := range floatSlice {
positional[i] = v
}
}
}
} else if slice, ok := params.([]any); ok {
positional = slice
} else if floatSlice, ok := params.([]float64); ok {
// Convert []float64 to []any
positional = make([]any, len(floatSlice))
for i, v := range floatSlice {
positional[i] = v
}
}
return nil, positional
}
// Handle named parameters (map-like)
if paramsMap, ok := params.(map[string]any); ok {
modified := make(map[string]any, len(paramsMap))
for key, value := range paramsMap {
if len(key) > 0 && key[0] != ':' {
modified[":"+key] = value
} else {
modified[key] = value
}
}
return modified, nil
}
return nil, nil
}
// luaSQLQuery executes a SQL query and returns results to Lua
func luaSQLQuery(state *luajit.State) int {
// Get database name
@ -212,18 +299,71 @@ func luaSQLQuery(state *luajit.State) int {
}
query := state.ToString(2)
// Get connection ID (optional for compatibility)
// Check if using positional parameters
isPositional := false
var positionalParams []any
// Get connection ID (optional)
var connID string
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4)
// Check if we have positional parameters instead of a params table
if state.GetTop() >= 3 && !state.IsTable(3) {
isPositional = true
paramCount := state.GetTop() - 2 // Count all args after db and query
// Adjust connection ID index if we have positional params
if paramCount > 0 {
// Last parameter might be connID if it's a string
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
connID = state.ToString(lastIdx)
paramCount-- // Exclude connID from param count
}
}
// Create array for positional parameters
positionalParams = make([]any, paramCount)
// Collect all parameters
for i := 0; i < paramCount; i++ {
paramIdx := i + 3 // Params start at index 3
// Convert to appropriate Go value
var value any
switch state.GetType(paramIdx) {
case luajit.TypeNumber:
value = state.ToNumber(paramIdx)
case luajit.TypeString:
value = state.ToString(paramIdx)
case luajit.TypeBoolean:
value = state.ToBoolean(paramIdx)
case luajit.TypeNil:
value = nil
default:
// Try to convert as generic value
var err error
value, err = state.ToValue(paramIdx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: failed to convert parameter %d: %s", i+1, err.Error()))
return -1
}
}
positionalParams[i] = value
}
} else {
// Generate a temporary connection ID
connID = fmt.Sprintf("temp_%p", &query)
// Original named parameter table handling
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4)
} else {
// Generate a temporary connection ID
connID = fmt.Sprintf("temp_%p", &query)
}
}
// Get parameters (optional)
var params map[string]any
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
// Get parameters (optional for named parameters)
var params any
if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
var err error
params, err = state.ToTable(3)
if err != nil {
@ -240,20 +380,33 @@ func luaSQLQuery(state *luajit.State) int {
}
// For temporary connections, defer release
if !strings.HasPrefix(connID, "temp_") {
defer pool.Put(conn)
if strings.HasPrefix(connID, "temp_") {
defer func() {
// Release the connection
sqliteManager.mu.Lock()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
// Remove from active connections
sqliteManager.mu.Lock()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
pool.Put(conn)
}()
}
// Execute query and collect results
var rows []map[string]any
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
Named: prepareNamedParams(params),
// Prepare params based on type
namedParams, positional := prepareParams(params)
// If we have direct positional params from function args, use those
if isPositional {
positional = positionalParams
}
// Count actual placeholders in the query
placeholderCount := strings.Count(query, "?")
// Execute with appropriate parameter type
execOpts := &sqlitex.ExecOptions{
ResultFunc: func(stmt *sqlite.Stmt) error {
row := make(map[string]any)
columnCount := stmt.ColumnCount()
@ -285,7 +438,20 @@ func luaSQLQuery(state *luajit.State) int {
rows = append(rows, rowCopy)
return nil
},
})
}
// Set appropriate parameter type
if namedParams != nil {
execOpts.Named = namedParams
} else if positional != nil {
// Make sure we're not passing more positional parameters than placeholders
if len(positional) > placeholderCount {
positional = positional[:placeholderCount]
}
execOpts.Args = positional
}
err = sqlitex.Execute(conn, query, execOpts)
if err != nil {
state.PushString("sqlite.query: " + err.Error())
@ -323,18 +489,71 @@ func luaSQLExec(state *luajit.State) int {
}
query := state.ToString(2)
// Get connection ID (optional for compatibility)
// Check if using positional parameters
isPositional := false
var positionalParams []any
// Get connection ID (optional)
var connID string
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4)
// Check if we have positional parameters instead of a params table
if state.GetTop() >= 3 && !state.IsTable(3) {
isPositional = true
paramCount := state.GetTop() - 2 // Count all args after db and query
// Adjust connection ID index if we have positional params
if paramCount > 0 {
// Last parameter might be connID if it's a string
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
connID = state.ToString(lastIdx)
paramCount-- // Exclude connID from param count
}
}
// Create array for positional parameters
positionalParams = make([]any, paramCount)
// Collect all parameters
for i := 0; i < paramCount; i++ {
paramIdx := i + 3 // Params start at index 3
// Convert to appropriate Go value
var value any
switch state.GetType(paramIdx) {
case luajit.TypeNumber:
value = state.ToNumber(paramIdx)
case luajit.TypeString:
value = state.ToString(paramIdx)
case luajit.TypeBoolean:
value = state.ToBoolean(paramIdx)
case luajit.TypeNil:
value = nil
default:
// Try to convert as generic value
var err error
value, err = state.ToValue(paramIdx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: failed to convert parameter %d: %s", i+1, err.Error()))
return -1
}
}
positionalParams[i] = value
}
} else {
// Generate a temporary connection ID
connID = fmt.Sprintf("temp_%p", &query)
// Original named parameter table handling
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4)
} else {
// Generate a temporary connection ID
connID = fmt.Sprintf("temp_%p", &query)
}
}
// Get parameters (optional)
var params map[string]any
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
// Get parameters (optional for named parameters)
var params any
if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
var err error
params, err = state.ToTable(3)
if err != nil {
@ -351,26 +570,55 @@ func luaSQLExec(state *luajit.State) int {
}
// For temporary connections, defer release
if !strings.HasPrefix(connID, "temp_") {
defer pool.Put(conn)
if strings.HasPrefix(connID, "temp_") {
defer func() {
// Release the connection
sqliteManager.mu.Lock()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
// Remove from active connections
sqliteManager.mu.Lock()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
pool.Put(conn)
}()
}
// Execute statement
if params != nil {
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
Named: prepareNamedParams(params),
})
// Count actual placeholders in the query
placeholderCount := strings.Count(query, "?")
// Prepare params based on type
namedParams, positional := prepareParams(params)
// If we have direct positional params from function args, use those
if isPositional {
positional = positionalParams
}
// Ensure we don't pass more parameters than placeholders
if positional != nil && len(positional) > placeholderCount {
positional = positional[:placeholderCount]
}
// Execute with appropriate parameter type
var execErr error
if isPositional || positional != nil {
// Execute with positional parameters
execOpts := &sqlitex.ExecOptions{
Args: positional,
}
execErr = sqlitex.Execute(conn, query, execOpts)
} else if namedParams != nil {
// Execute with named parameters
execOpts := &sqlitex.ExecOptions{
Named: namedParams,
}
execErr = sqlitex.Execute(conn, query, execOpts)
} else {
err = sqlitex.ExecScript(conn, query)
// Execute without parameters
execErr = sqlitex.ExecScript(conn, query)
}
if err != nil {
state.PushString("sqlite.exec: " + err.Error())
if execErr != nil {
state.PushString("sqlite.exec: " + execErr.Error())
return -1
}
@ -391,21 +639,3 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
return nil
}
func prepareNamedParams(params map[string]any) map[string]any {
if params == nil {
return nil
}
modified := make(map[string]any, len(params))
for key, value := range params {
if len(key) > 0 && key[0] != ':' {
modified[":"+key] = value
} else {
modified[key] = value
}
}
return modified
}

View File

@ -1,22 +1,70 @@
__active_sqlite_connections = {}
-- Helper function to handle parameters
local function handle_params(params, ...)
-- If params is a table, use it for named parameters
if type(params) == "table" then
return params
end
-- If we have varargs, collect them for positional parameters
local args = {...}
if #args > 0 or params ~= nil then
-- Include the first param in the args
table.insert(args, 1, params)
return args
end
return nil
end
-- Connection metatable
local connection_mt = {
__index = {
-- Execute a query and return results as a table
query = function(self, query, params)
query = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:query: query must be a string", 2)
end
return __sqlite_query(self.db_name, query, params)
-- Handle params (named or positional)
local processed_params = handle_params(params, ...)
-- Call with appropriate arguments
if type(processed_params) == "table" and processed_params[1] ~= nil then
-- Positional parameters - insert self.db_name and query at the beginning
table.insert(processed_params, 1, query)
table.insert(processed_params, 1, self.db_name)
-- Add connection ID at the end
table.insert(processed_params, self.id)
return __sqlite_query(unpack(processed_params))
else
-- Named parameters or no parameters
return __sqlite_query(self.db_name, query, processed_params, self.id)
end
end,
-- Execute a statement and return affected rows
exec = function(self, query, params)
exec = function(self, query, params, ...)
if type(query) ~= "string" then
error("connection:exec: query must be a string", 2)
end
return __sqlite_exec(self.db_name, query, params)
-- Handle params (named or positional)
local processed_params = handle_params(params, ...)
-- Call with appropriate arguments
if type(processed_params) == "table" and processed_params[1] ~= nil then
-- Positional parameters - insert self.db_name and query at the beginning
table.insert(processed_params, 1, query)
table.insert(processed_params, 1, self.db_name)
-- Add connection ID at the end
table.insert(processed_params, self.id)
return __sqlite_exec(unpack(processed_params))
else
-- Named parameters or no parameters
return __sqlite_exec(self.db_name, query, processed_params, self.id)
end
end,
-- Create a new table
@ -121,8 +169,14 @@ local connection_mt = {
end,
-- Get one row
get_one = function(self, query, params)
local results = self:query(query, params)
get_one = function(self, query, params, ...)
-- Handle both named and positional parameters
local results
if select('#', ...) > 0 then
results = self:query(query, params, ...)
else
results = self:query(query, params)
end
return results[1]
end,