add positional parameter support
This commit is contained in:
parent
c005066816
commit
551f311755
@ -29,7 +29,7 @@ func passwordHash(state *luajit.State) int {
|
|||||||
|
|
||||||
params := &argon2id.Params{
|
params := &argon2id.Params{
|
||||||
Memory: 64 * 1024,
|
Memory: 64 * 1024,
|
||||||
Iterations: 3,
|
Iterations: 4,
|
||||||
Parallelism: 4,
|
Parallelism: 4,
|
||||||
SaltLength: 16,
|
SaltLength: 16,
|
||||||
KeyLength: 32,
|
KeyLength: 32,
|
||||||
@ -38,46 +38,31 @@ func passwordHash(state *luajit.State) int {
|
|||||||
if state.IsTable(2) {
|
if state.IsTable(2) {
|
||||||
state.GetField(2, "memory")
|
state.GetField(2, "memory")
|
||||||
if state.IsNumber(-1) {
|
if state.IsNumber(-1) {
|
||||||
params.Memory = uint32(state.ToNumber(-1))
|
params.Memory = max(uint32(state.ToNumber(-1)), 8*1024)
|
||||||
if params.Memory < 8*1024 {
|
|
||||||
params.Memory = 8 * 1024 // Minimum 8MB
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
state.GetField(2, "iterations")
|
state.GetField(2, "iterations")
|
||||||
if state.IsNumber(-1) {
|
if state.IsNumber(-1) {
|
||||||
params.Iterations = uint32(state.ToNumber(-1))
|
params.Iterations = max(uint32(state.ToNumber(-1)), 1)
|
||||||
if params.Iterations < 1 {
|
|
||||||
params.Iterations = 1 // Minimum 1 iteration
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
state.GetField(2, "parallelism")
|
state.GetField(2, "parallelism")
|
||||||
if state.IsNumber(-1) {
|
if state.IsNumber(-1) {
|
||||||
params.Parallelism = uint8(state.ToNumber(-1))
|
params.Parallelism = max(uint8(state.ToNumber(-1)), 1)
|
||||||
if params.Parallelism < 1 {
|
|
||||||
params.Parallelism = 1 // Minimum 1 thread
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
state.GetField(2, "salt_length")
|
state.GetField(2, "salt_length")
|
||||||
if state.IsNumber(-1) {
|
if state.IsNumber(-1) {
|
||||||
params.SaltLength = uint32(state.ToNumber(-1))
|
params.SaltLength = max(uint32(state.ToNumber(-1)), 8)
|
||||||
if params.SaltLength < 8 {
|
|
||||||
params.SaltLength = 8 // Minimum 8 bytes
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
|
|
||||||
state.GetField(2, "key_length")
|
state.GetField(2, "key_length")
|
||||||
if state.IsNumber(-1) {
|
if state.IsNumber(-1) {
|
||||||
params.KeyLength = uint32(state.ToNumber(-1))
|
params.KeyLength = max(uint32(state.ToNumber(-1)), 16)
|
||||||
if params.KeyLength < 16 {
|
|
||||||
params.KeyLength = 16 // Minimum 16 bytes
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
state.Pop(1)
|
state.Pop(1)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package runner
|
package runner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -178,10 +179,10 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a connection
|
// Get a connection using the newer Take API
|
||||||
dbConn := pool.Get(nil)
|
dbConn, err := pool.Take(context.Background())
|
||||||
if dbConn == nil {
|
if err != nil {
|
||||||
return nil, nil, errors.New("failed to get connection from pool")
|
return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store connection
|
// Store connection
|
||||||
@ -196,6 +197,92 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
|
|||||||
return dbConn, pool, nil
|
return dbConn, pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectParamType determines if parameters are positional or named
|
||||||
|
func detectParamType(params any) (isArray bool) {
|
||||||
|
if params == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a map[string]any
|
||||||
|
if paramsMap, ok := params.(map[string]any); ok {
|
||||||
|
// Check for the empty string key which indicates an array
|
||||||
|
if array, hasArray := paramsMap[""]; hasArray {
|
||||||
|
// Verify it's actually an array
|
||||||
|
if _, isSlice := array.([]any); isSlice {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, isFloatSlice := array.([]float64); isFloatSlice {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's already a slice type
|
||||||
|
if _, ok := params.([]any); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := params.([]float64); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareParams processes parameters for SQLite queries
|
||||||
|
func prepareParams(params any) (map[string]any, []any) {
|
||||||
|
if params == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle positional parameters (array-like)
|
||||||
|
if detectParamType(params) {
|
||||||
|
var positional []any
|
||||||
|
|
||||||
|
// Extract array from special map format
|
||||||
|
if paramsMap, ok := params.(map[string]any); ok {
|
||||||
|
if array, hasArray := paramsMap[""]; hasArray {
|
||||||
|
if slice, ok := array.([]any); ok {
|
||||||
|
positional = slice
|
||||||
|
} else if floatSlice, ok := array.([]float64); ok {
|
||||||
|
// Convert []float64 to []any
|
||||||
|
positional = make([]any, len(floatSlice))
|
||||||
|
for i, v := range floatSlice {
|
||||||
|
positional[i] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if slice, ok := params.([]any); ok {
|
||||||
|
positional = slice
|
||||||
|
} else if floatSlice, ok := params.([]float64); ok {
|
||||||
|
// Convert []float64 to []any
|
||||||
|
positional = make([]any, len(floatSlice))
|
||||||
|
for i, v := range floatSlice {
|
||||||
|
positional[i] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, positional
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle named parameters (map-like)
|
||||||
|
if paramsMap, ok := params.(map[string]any); ok {
|
||||||
|
modified := make(map[string]any, len(paramsMap))
|
||||||
|
|
||||||
|
for key, value := range paramsMap {
|
||||||
|
if len(key) > 0 && key[0] != ':' {
|
||||||
|
modified[":"+key] = value
|
||||||
|
} else {
|
||||||
|
modified[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return modified, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// luaSQLQuery executes a SQL query and returns results to Lua
|
// luaSQLQuery executes a SQL query and returns results to Lua
|
||||||
func luaSQLQuery(state *luajit.State) int {
|
func luaSQLQuery(state *luajit.State) int {
|
||||||
// Get database name
|
// Get database name
|
||||||
@ -212,18 +299,71 @@ func luaSQLQuery(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
query := state.ToString(2)
|
query := state.ToString(2)
|
||||||
|
|
||||||
// Get connection ID (optional for compatibility)
|
// Check if using positional parameters
|
||||||
|
isPositional := false
|
||||||
|
var positionalParams []any
|
||||||
|
|
||||||
|
// Get connection ID (optional)
|
||||||
var connID string
|
var connID string
|
||||||
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
|
||||||
connID = state.ToString(4)
|
// Check if we have positional parameters instead of a params table
|
||||||
|
if state.GetTop() >= 3 && !state.IsTable(3) {
|
||||||
|
isPositional = true
|
||||||
|
paramCount := state.GetTop() - 2 // Count all args after db and query
|
||||||
|
|
||||||
|
// Adjust connection ID index if we have positional params
|
||||||
|
if paramCount > 0 {
|
||||||
|
// Last parameter might be connID if it's a string
|
||||||
|
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
|
||||||
|
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
|
||||||
|
connID = state.ToString(lastIdx)
|
||||||
|
paramCount-- // Exclude connID from param count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create array for positional parameters
|
||||||
|
positionalParams = make([]any, paramCount)
|
||||||
|
|
||||||
|
// Collect all parameters
|
||||||
|
for i := 0; i < paramCount; i++ {
|
||||||
|
paramIdx := i + 3 // Params start at index 3
|
||||||
|
|
||||||
|
// Convert to appropriate Go value
|
||||||
|
var value any
|
||||||
|
switch state.GetType(paramIdx) {
|
||||||
|
case luajit.TypeNumber:
|
||||||
|
value = state.ToNumber(paramIdx)
|
||||||
|
case luajit.TypeString:
|
||||||
|
value = state.ToString(paramIdx)
|
||||||
|
case luajit.TypeBoolean:
|
||||||
|
value = state.ToBoolean(paramIdx)
|
||||||
|
case luajit.TypeNil:
|
||||||
|
value = nil
|
||||||
|
default:
|
||||||
|
// Try to convert as generic value
|
||||||
|
var err error
|
||||||
|
value, err = state.ToValue(paramIdx)
|
||||||
|
if err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.query: failed to convert parameter %d: %s", i+1, err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
positionalParams[i] = value
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Generate a temporary connection ID
|
// Original named parameter table handling
|
||||||
connID = fmt.Sprintf("temp_%p", &query)
|
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
||||||
|
connID = state.ToString(4)
|
||||||
|
} else {
|
||||||
|
// Generate a temporary connection ID
|
||||||
|
connID = fmt.Sprintf("temp_%p", &query)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get parameters (optional)
|
// Get parameters (optional for named parameters)
|
||||||
var params map[string]any
|
var params any
|
||||||
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
||||||
var err error
|
var err error
|
||||||
params, err = state.ToTable(3)
|
params, err = state.ToTable(3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -240,20 +380,33 @@ func luaSQLQuery(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// For temporary connections, defer release
|
// For temporary connections, defer release
|
||||||
if !strings.HasPrefix(connID, "temp_") {
|
if strings.HasPrefix(connID, "temp_") {
|
||||||
defer pool.Put(conn)
|
defer func() {
|
||||||
|
// Release the connection
|
||||||
|
sqliteManager.mu.Lock()
|
||||||
|
delete(sqliteManager.activeConns, connID)
|
||||||
|
sqliteManager.mu.Unlock()
|
||||||
|
|
||||||
// Remove from active connections
|
pool.Put(conn)
|
||||||
sqliteManager.mu.Lock()
|
}()
|
||||||
delete(sqliteManager.activeConns, connID)
|
|
||||||
sqliteManager.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute query and collect results
|
// Execute query and collect results
|
||||||
var rows []map[string]any
|
var rows []map[string]any
|
||||||
|
|
||||||
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
// Prepare params based on type
|
||||||
Named: prepareNamedParams(params),
|
namedParams, positional := prepareParams(params)
|
||||||
|
|
||||||
|
// If we have direct positional params from function args, use those
|
||||||
|
if isPositional {
|
||||||
|
positional = positionalParams
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count actual placeholders in the query
|
||||||
|
placeholderCount := strings.Count(query, "?")
|
||||||
|
|
||||||
|
// Execute with appropriate parameter type
|
||||||
|
execOpts := &sqlitex.ExecOptions{
|
||||||
ResultFunc: func(stmt *sqlite.Stmt) error {
|
ResultFunc: func(stmt *sqlite.Stmt) error {
|
||||||
row := make(map[string]any)
|
row := make(map[string]any)
|
||||||
columnCount := stmt.ColumnCount()
|
columnCount := stmt.ColumnCount()
|
||||||
@ -285,7 +438,20 @@ func luaSQLQuery(state *luajit.State) int {
|
|||||||
rows = append(rows, rowCopy)
|
rows = append(rows, rowCopy)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
|
// Set appropriate parameter type
|
||||||
|
if namedParams != nil {
|
||||||
|
execOpts.Named = namedParams
|
||||||
|
} else if positional != nil {
|
||||||
|
// Make sure we're not passing more positional parameters than placeholders
|
||||||
|
if len(positional) > placeholderCount {
|
||||||
|
positional = positional[:placeholderCount]
|
||||||
|
}
|
||||||
|
execOpts.Args = positional
|
||||||
|
}
|
||||||
|
|
||||||
|
err = sqlitex.Execute(conn, query, execOpts)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("sqlite.query: " + err.Error())
|
state.PushString("sqlite.query: " + err.Error())
|
||||||
@ -323,18 +489,71 @@ func luaSQLExec(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
query := state.ToString(2)
|
query := state.ToString(2)
|
||||||
|
|
||||||
// Get connection ID (optional for compatibility)
|
// Check if using positional parameters
|
||||||
|
isPositional := false
|
||||||
|
var positionalParams []any
|
||||||
|
|
||||||
|
// Get connection ID (optional)
|
||||||
var connID string
|
var connID string
|
||||||
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
|
||||||
connID = state.ToString(4)
|
// Check if we have positional parameters instead of a params table
|
||||||
|
if state.GetTop() >= 3 && !state.IsTable(3) {
|
||||||
|
isPositional = true
|
||||||
|
paramCount := state.GetTop() - 2 // Count all args after db and query
|
||||||
|
|
||||||
|
// Adjust connection ID index if we have positional params
|
||||||
|
if paramCount > 0 {
|
||||||
|
// Last parameter might be connID if it's a string
|
||||||
|
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
|
||||||
|
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
|
||||||
|
connID = state.ToString(lastIdx)
|
||||||
|
paramCount-- // Exclude connID from param count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create array for positional parameters
|
||||||
|
positionalParams = make([]any, paramCount)
|
||||||
|
|
||||||
|
// Collect all parameters
|
||||||
|
for i := 0; i < paramCount; i++ {
|
||||||
|
paramIdx := i + 3 // Params start at index 3
|
||||||
|
|
||||||
|
// Convert to appropriate Go value
|
||||||
|
var value any
|
||||||
|
switch state.GetType(paramIdx) {
|
||||||
|
case luajit.TypeNumber:
|
||||||
|
value = state.ToNumber(paramIdx)
|
||||||
|
case luajit.TypeString:
|
||||||
|
value = state.ToString(paramIdx)
|
||||||
|
case luajit.TypeBoolean:
|
||||||
|
value = state.ToBoolean(paramIdx)
|
||||||
|
case luajit.TypeNil:
|
||||||
|
value = nil
|
||||||
|
default:
|
||||||
|
// Try to convert as generic value
|
||||||
|
var err error
|
||||||
|
value, err = state.ToValue(paramIdx)
|
||||||
|
if err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.exec: failed to convert parameter %d: %s", i+1, err.Error()))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
positionalParams[i] = value
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Generate a temporary connection ID
|
// Original named parameter table handling
|
||||||
connID = fmt.Sprintf("temp_%p", &query)
|
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
||||||
|
connID = state.ToString(4)
|
||||||
|
} else {
|
||||||
|
// Generate a temporary connection ID
|
||||||
|
connID = fmt.Sprintf("temp_%p", &query)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get parameters (optional)
|
// Get parameters (optional for named parameters)
|
||||||
var params map[string]any
|
var params any
|
||||||
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
||||||
var err error
|
var err error
|
||||||
params, err = state.ToTable(3)
|
params, err = state.ToTable(3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -351,26 +570,55 @@ func luaSQLExec(state *luajit.State) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// For temporary connections, defer release
|
// For temporary connections, defer release
|
||||||
if !strings.HasPrefix(connID, "temp_") {
|
if strings.HasPrefix(connID, "temp_") {
|
||||||
defer pool.Put(conn)
|
defer func() {
|
||||||
|
// Release the connection
|
||||||
|
sqliteManager.mu.Lock()
|
||||||
|
delete(sqliteManager.activeConns, connID)
|
||||||
|
sqliteManager.mu.Unlock()
|
||||||
|
|
||||||
// Remove from active connections
|
pool.Put(conn)
|
||||||
sqliteManager.mu.Lock()
|
}()
|
||||||
delete(sqliteManager.activeConns, connID)
|
|
||||||
sqliteManager.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute statement
|
// Count actual placeholders in the query
|
||||||
if params != nil {
|
placeholderCount := strings.Count(query, "?")
|
||||||
err = sqlitex.Execute(conn, query, &sqlitex.ExecOptions{
|
|
||||||
Named: prepareNamedParams(params),
|
// Prepare params based on type
|
||||||
})
|
namedParams, positional := prepareParams(params)
|
||||||
|
|
||||||
|
// If we have direct positional params from function args, use those
|
||||||
|
if isPositional {
|
||||||
|
positional = positionalParams
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we don't pass more parameters than placeholders
|
||||||
|
if positional != nil && len(positional) > placeholderCount {
|
||||||
|
positional = positional[:placeholderCount]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute with appropriate parameter type
|
||||||
|
var execErr error
|
||||||
|
|
||||||
|
if isPositional || positional != nil {
|
||||||
|
// Execute with positional parameters
|
||||||
|
execOpts := &sqlitex.ExecOptions{
|
||||||
|
Args: positional,
|
||||||
|
}
|
||||||
|
execErr = sqlitex.Execute(conn, query, execOpts)
|
||||||
|
} else if namedParams != nil {
|
||||||
|
// Execute with named parameters
|
||||||
|
execOpts := &sqlitex.ExecOptions{
|
||||||
|
Named: namedParams,
|
||||||
|
}
|
||||||
|
execErr = sqlitex.Execute(conn, query, execOpts)
|
||||||
} else {
|
} else {
|
||||||
err = sqlitex.ExecScript(conn, query)
|
// Execute without parameters
|
||||||
|
execErr = sqlitex.ExecScript(conn, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if execErr != nil {
|
||||||
state.PushString("sqlite.exec: " + err.Error())
|
state.PushString("sqlite.exec: " + execErr.Error())
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -391,21 +639,3 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareNamedParams(params map[string]any) map[string]any {
|
|
||||||
if params == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
modified := make(map[string]any, len(params))
|
|
||||||
|
|
||||||
for key, value := range params {
|
|
||||||
if len(key) > 0 && key[0] != ':' {
|
|
||||||
modified[":"+key] = value
|
|
||||||
} else {
|
|
||||||
modified[key] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return modified
|
|
||||||
}
|
|
||||||
|
@ -1,22 +1,70 @@
|
|||||||
__active_sqlite_connections = {}
|
__active_sqlite_connections = {}
|
||||||
|
|
||||||
|
-- Helper function to handle parameters
|
||||||
|
local function handle_params(params, ...)
|
||||||
|
-- If params is a table, use it for named parameters
|
||||||
|
if type(params) == "table" then
|
||||||
|
return params
|
||||||
|
end
|
||||||
|
|
||||||
|
-- If we have varargs, collect them for positional parameters
|
||||||
|
local args = {...}
|
||||||
|
if #args > 0 or params ~= nil then
|
||||||
|
-- Include the first param in the args
|
||||||
|
table.insert(args, 1, params)
|
||||||
|
return args
|
||||||
|
end
|
||||||
|
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
|
||||||
-- Connection metatable
|
-- Connection metatable
|
||||||
local connection_mt = {
|
local connection_mt = {
|
||||||
__index = {
|
__index = {
|
||||||
-- Execute a query and return results as a table
|
-- Execute a query and return results as a table
|
||||||
query = function(self, query, params)
|
query = function(self, query, params, ...)
|
||||||
if type(query) ~= "string" then
|
if type(query) ~= "string" then
|
||||||
error("connection:query: query must be a string", 2)
|
error("connection:query: query must be a string", 2)
|
||||||
end
|
end
|
||||||
return __sqlite_query(self.db_name, query, params)
|
|
||||||
|
-- Handle params (named or positional)
|
||||||
|
local processed_params = handle_params(params, ...)
|
||||||
|
|
||||||
|
-- Call with appropriate arguments
|
||||||
|
if type(processed_params) == "table" and processed_params[1] ~= nil then
|
||||||
|
-- Positional parameters - insert self.db_name and query at the beginning
|
||||||
|
table.insert(processed_params, 1, query)
|
||||||
|
table.insert(processed_params, 1, self.db_name)
|
||||||
|
-- Add connection ID at the end
|
||||||
|
table.insert(processed_params, self.id)
|
||||||
|
return __sqlite_query(unpack(processed_params))
|
||||||
|
else
|
||||||
|
-- Named parameters or no parameters
|
||||||
|
return __sqlite_query(self.db_name, query, processed_params, self.id)
|
||||||
|
end
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Execute a statement and return affected rows
|
-- Execute a statement and return affected rows
|
||||||
exec = function(self, query, params)
|
exec = function(self, query, params, ...)
|
||||||
if type(query) ~= "string" then
|
if type(query) ~= "string" then
|
||||||
error("connection:exec: query must be a string", 2)
|
error("connection:exec: query must be a string", 2)
|
||||||
end
|
end
|
||||||
return __sqlite_exec(self.db_name, query, params)
|
|
||||||
|
-- Handle params (named or positional)
|
||||||
|
local processed_params = handle_params(params, ...)
|
||||||
|
|
||||||
|
-- Call with appropriate arguments
|
||||||
|
if type(processed_params) == "table" and processed_params[1] ~= nil then
|
||||||
|
-- Positional parameters - insert self.db_name and query at the beginning
|
||||||
|
table.insert(processed_params, 1, query)
|
||||||
|
table.insert(processed_params, 1, self.db_name)
|
||||||
|
-- Add connection ID at the end
|
||||||
|
table.insert(processed_params, self.id)
|
||||||
|
return __sqlite_exec(unpack(processed_params))
|
||||||
|
else
|
||||||
|
-- Named parameters or no parameters
|
||||||
|
return __sqlite_exec(self.db_name, query, processed_params, self.id)
|
||||||
|
end
|
||||||
end,
|
end,
|
||||||
|
|
||||||
-- Create a new table
|
-- Create a new table
|
||||||
@ -121,8 +169,14 @@ local connection_mt = {
|
|||||||
end,
|
end,
|
||||||
|
|
||||||
-- Get one row
|
-- Get one row
|
||||||
get_one = function(self, query, params)
|
get_one = function(self, query, params, ...)
|
||||||
local results = self:query(query, params)
|
-- Handle both named and positional parameters
|
||||||
|
local results
|
||||||
|
if select('#', ...) > 0 then
|
||||||
|
results = self:query(query, params, ...)
|
||||||
|
else
|
||||||
|
results = self:query(query, params)
|
||||||
|
end
|
||||||
return results[1]
|
return results[1]
|
||||||
end,
|
end,
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user