Refactor sqlite, up password mem default, refactor fs
This commit is contained in:
parent
551f311755
commit
30a126909b
@ -1,29 +1,28 @@
|
|||||||
local fs = {}
|
local fs = {}
|
||||||
|
|
||||||
-- File Operations
|
fs.read = function(path)
|
||||||
fs.read_file = function(path)
|
|
||||||
if type(path) ~= "string" then
|
if type(path) ~= "string" then
|
||||||
error("fs.read_file: path must be a string", 2)
|
error("fs.read: path must be a string", 2)
|
||||||
end
|
end
|
||||||
return __fs_read_file(path)
|
return __fs_read_file(path)
|
||||||
end
|
end
|
||||||
|
|
||||||
fs.write_file = function(path, content)
|
fs.write = function(path, content)
|
||||||
if type(path) ~= "string" then
|
if type(path) ~= "string" then
|
||||||
error("fs.write_file: path must be a string", 2)
|
error("fs.write: path must be a string", 2)
|
||||||
end
|
end
|
||||||
if type(content) ~= "string" then
|
if type(content) ~= "string" then
|
||||||
error("fs.write_file: content must be a string", 2)
|
error("fs.write: content must be a string", 2)
|
||||||
end
|
end
|
||||||
return __fs_write_file(path, content)
|
return __fs_write_file(path, content)
|
||||||
end
|
end
|
||||||
|
|
||||||
fs.append_file = function(path, content)
|
fs.append = function(path, content)
|
||||||
if type(path) ~= "string" then
|
if type(path) ~= "string" then
|
||||||
error("fs.append_file: path must be a string", 2)
|
error("fs.append: path must be a string", 2)
|
||||||
end
|
end
|
||||||
if type(content) ~= "string" then
|
if type(content) ~= "string" then
|
||||||
error("fs.append_file: content must be a string", 2)
|
error("fs.append: content must be a string", 2)
|
||||||
end
|
end
|
||||||
return __fs_append_file(path, content)
|
return __fs_append_file(path, content)
|
||||||
end
|
end
|
||||||
@ -35,16 +34,16 @@ fs.exists = function(path)
|
|||||||
return __fs_exists(path)
|
return __fs_exists(path)
|
||||||
end
|
end
|
||||||
|
|
||||||
fs.remove_file = function(path)
|
fs.remove = function(path)
|
||||||
if type(path) ~= "string" then
|
if type(path) ~= "string" then
|
||||||
error("fs.remove_file: path must be a string", 2)
|
error("fs.remove: path must be a string", 2)
|
||||||
end
|
end
|
||||||
return __fs_remove_file(path)
|
return __fs_remove_file(path)
|
||||||
end
|
end
|
||||||
|
|
||||||
fs.get_info = function(path)
|
fs.info = function(path)
|
||||||
if type(path) ~= "string" then
|
if type(path) ~= "string" then
|
||||||
error("fs.get_info: path must be a string", 2)
|
error("fs.info: path must be a string", 2)
|
||||||
end
|
end
|
||||||
local info = __fs_get_info(path)
|
local info = __fs_get_info(path)
|
||||||
|
|
||||||
@ -57,24 +56,24 @@ fs.get_info = function(path)
|
|||||||
end
|
end
|
||||||
|
|
||||||
-- Directory Operations
|
-- Directory Operations
|
||||||
fs.make_dir = function(path, mode)
|
fs.mkdir = function(path, mode)
|
||||||
if type(path) ~= "string" then
|
if type(path) ~= "string" then
|
||||||
error("fs.make_dir: path must be a string", 2)
|
error("fs.mkdir: path must be a string", 2)
|
||||||
end
|
end
|
||||||
mode = mode or 0755
|
mode = mode or 0755
|
||||||
return __fs_make_dir(path, mode)
|
return __fs_make_dir(path, mode)
|
||||||
end
|
end
|
||||||
|
|
||||||
fs.list_dir = function(path)
|
fs.ls = function(path)
|
||||||
if type(path) ~= "string" then
|
if type(path) ~= "string" then
|
||||||
error("fs.list_dir: path must be a string", 2)
|
error("fs.ls: path must be a string", 2)
|
||||||
end
|
end
|
||||||
return __fs_list_dir(path)
|
return __fs_list_dir(path)
|
||||||
end
|
end
|
||||||
|
|
||||||
fs.remove_dir = function(path, recursive)
|
fs.rmdir = function(path, recursive)
|
||||||
if type(path) ~= "string" then
|
if type(path) ~= "string" then
|
||||||
error("fs.remove_dir: path must be a string", 2)
|
error("fs.rmdir: path must be a string", 2)
|
||||||
end
|
end
|
||||||
recursive = recursive or false
|
recursive = recursive or false
|
||||||
return __fs_remove_dir(path, recursive)
|
return __fs_remove_dir(path, recursive)
|
||||||
|
@ -28,7 +28,7 @@ func passwordHash(state *luajit.State) int {
|
|||||||
password := state.ToString(1)
|
password := state.ToString(1)
|
||||||
|
|
||||||
params := &argon2id.Params{
|
params := &argon2id.Params{
|
||||||
Memory: 64 * 1024,
|
Memory: 128 * 1024,
|
||||||
Iterations: 4,
|
Iterations: 4,
|
||||||
Parallelism: 4,
|
Parallelism: 4,
|
||||||
SaltLength: 16,
|
SaltLength: 16,
|
||||||
|
@ -585,8 +585,8 @@ local password = {}
|
|||||||
|
|
||||||
-- Hash a password using Argon2id
|
-- Hash a password using Argon2id
|
||||||
-- Options:
|
-- Options:
|
||||||
-- memory: Amount of memory to use in KB (default: 64MB)
|
-- memory: Amount of memory to use in KB (default: 128MB)
|
||||||
-- iterations: Number of iterations (default: 3)
|
-- iterations: Number of iterations (default: 4)
|
||||||
-- parallelism: Number of threads (default: 4)
|
-- parallelism: Number of threads (default: 4)
|
||||||
-- salt_length: Length of salt in bytes (default: 16)
|
-- salt_length: Length of salt in bytes (default: 16)
|
||||||
-- key_length: Length of the derived key in bytes (default: 32)
|
-- key_length: Length of the derived key in bytes (default: 32)
|
||||||
|
@ -33,7 +33,6 @@ type SQLiteManager struct {
|
|||||||
dataDir string
|
dataDir string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global manager
|
|
||||||
var sqliteManager *SQLiteManager
|
var sqliteManager *SQLiteManager
|
||||||
|
|
||||||
// InitSQLite initializes the SQLite manager
|
// InitSQLite initializes the SQLite manager
|
||||||
@ -55,7 +54,7 @@ func CleanupSQLite() {
|
|||||||
sqliteManager.mu.Lock()
|
sqliteManager.mu.Lock()
|
||||||
defer sqliteManager.mu.Unlock()
|
defer sqliteManager.mu.Unlock()
|
||||||
|
|
||||||
// Release all active connections
|
// Release all connections and close pools
|
||||||
for id, conn := range sqliteManager.activeConns {
|
for id, conn := range sqliteManager.activeConns {
|
||||||
if conn.Pool != nil {
|
if conn.Pool != nil {
|
||||||
conn.Pool.Put(conn.Conn)
|
conn.Pool.Put(conn.Conn)
|
||||||
@ -63,7 +62,6 @@ func CleanupSQLite() {
|
|||||||
delete(sqliteManager.activeConns, id)
|
delete(sqliteManager.activeConns, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close all pools
|
|
||||||
for name, pool := range sqliteManager.pools {
|
for name, pool := range sqliteManager.pools {
|
||||||
if err := pool.Close(); err != nil {
|
if err := pool.Close(); err != nil {
|
||||||
logger.Error("Failed to close database %s: %v", name, err)
|
logger.Error("Failed to close database %s: %v", name, err)
|
||||||
@ -94,13 +92,10 @@ func ReleaseActiveConnections(state *luajit.State) {
|
|||||||
// Iterate through active connections
|
// Iterate through active connections
|
||||||
state.PushNil() // Start iteration
|
state.PushNil() // Start iteration
|
||||||
for state.Next(-2) {
|
for state.Next(-2) {
|
||||||
// Stack now has key at -2 and value at -1
|
|
||||||
if state.IsTable(-1) {
|
if state.IsTable(-1) {
|
||||||
state.GetField(-1, "id")
|
state.GetField(-1, "id")
|
||||||
if state.IsString(-1) {
|
if state.IsString(-1) {
|
||||||
connID := state.ToString(-1)
|
connID := state.ToString(-1)
|
||||||
|
|
||||||
// Release connection from Go side
|
|
||||||
if conn, exists := sqliteManager.activeConns[connID]; exists {
|
if conn, exists := sqliteManager.activeConns[connID]; exists {
|
||||||
if conn.Pool != nil {
|
if conn.Pool != nil {
|
||||||
conn.Pool.Put(conn.Conn)
|
conn.Pool.Put(conn.Conn)
|
||||||
@ -130,16 +125,15 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
|||||||
return nil, errors.New("invalid database name")
|
return nil, errors.New("invalid database name")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for existing pool
|
// Check for existing pool with read lock
|
||||||
sqliteManager.mu.RLock()
|
sqliteManager.mu.RLock()
|
||||||
pool, exists := sqliteManager.pools[dbName]
|
pool, exists := sqliteManager.pools[dbName]
|
||||||
sqliteManager.mu.RUnlock()
|
sqliteManager.mu.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists {
|
||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new pool
|
// Create new pool with write lock
|
||||||
sqliteManager.mu.Lock()
|
sqliteManager.mu.Lock()
|
||||||
defer sqliteManager.mu.Unlock()
|
defer sqliteManager.mu.Unlock()
|
||||||
|
|
||||||
@ -148,12 +142,9 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
|||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create database file path
|
// Create database file path and pool
|
||||||
dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
|
dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
|
||||||
|
|
||||||
// Create the pool
|
|
||||||
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
|
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
@ -179,7 +170,7 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a connection using the newer Take API
|
// Get a connection
|
||||||
dbConn, err := pool.Take(context.Background())
|
dbConn, err := pool.Take(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
|
return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
|
||||||
@ -197,128 +188,20 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
|
|||||||
return dbConn, pool, nil
|
return dbConn, pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// detectParamType determines if parameters are positional or named
|
// processParams extracts parameters and connection ID from Lua state
|
||||||
func detectParamType(params any) (isArray bool) {
|
func processParams(state *luajit.State, defaultConnID string) (params any, connID string, isPositional bool, positionalParams []any, err error) {
|
||||||
if params == nil {
|
connID = defaultConnID
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if it's a map[string]any
|
|
||||||
if paramsMap, ok := params.(map[string]any); ok {
|
|
||||||
// Check for the empty string key which indicates an array
|
|
||||||
if array, hasArray := paramsMap[""]; hasArray {
|
|
||||||
// Verify it's actually an array
|
|
||||||
if _, isSlice := array.([]any); isSlice {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if _, isFloatSlice := array.([]float64); isFloatSlice {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// If it's already a slice type
|
|
||||||
if _, ok := params.([]any); ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if _, ok := params.([]float64); ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareParams processes parameters for SQLite queries
|
|
||||||
func prepareParams(params any) (map[string]any, []any) {
|
|
||||||
if params == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle positional parameters (array-like)
|
|
||||||
if detectParamType(params) {
|
|
||||||
var positional []any
|
|
||||||
|
|
||||||
// Extract array from special map format
|
|
||||||
if paramsMap, ok := params.(map[string]any); ok {
|
|
||||||
if array, hasArray := paramsMap[""]; hasArray {
|
|
||||||
if slice, ok := array.([]any); ok {
|
|
||||||
positional = slice
|
|
||||||
} else if floatSlice, ok := array.([]float64); ok {
|
|
||||||
// Convert []float64 to []any
|
|
||||||
positional = make([]any, len(floatSlice))
|
|
||||||
for i, v := range floatSlice {
|
|
||||||
positional[i] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if slice, ok := params.([]any); ok {
|
|
||||||
positional = slice
|
|
||||||
} else if floatSlice, ok := params.([]float64); ok {
|
|
||||||
// Convert []float64 to []any
|
|
||||||
positional = make([]any, len(floatSlice))
|
|
||||||
for i, v := range floatSlice {
|
|
||||||
positional[i] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, positional
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle named parameters (map-like)
|
|
||||||
if paramsMap, ok := params.(map[string]any); ok {
|
|
||||||
modified := make(map[string]any, len(paramsMap))
|
|
||||||
|
|
||||||
for key, value := range paramsMap {
|
|
||||||
if len(key) > 0 && key[0] != ':' {
|
|
||||||
modified[":"+key] = value
|
|
||||||
} else {
|
|
||||||
modified[key] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return modified, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// luaSQLQuery executes a SQL query and returns results to Lua
|
|
||||||
func luaSQLQuery(state *luajit.State) int {
|
|
||||||
// Get database name
|
|
||||||
if !state.IsString(1) {
|
|
||||||
state.PushString("sqlite.query: database name must be a string")
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
dbName := state.ToString(1)
|
|
||||||
|
|
||||||
// Get query
|
|
||||||
if !state.IsString(2) {
|
|
||||||
state.PushString("sqlite.query: query must be a string")
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
query := state.ToString(2)
|
|
||||||
|
|
||||||
// Check if using positional parameters
|
// Check if using positional parameters
|
||||||
isPositional := false
|
|
||||||
var positionalParams []any
|
|
||||||
|
|
||||||
// Get connection ID (optional)
|
|
||||||
var connID string
|
|
||||||
|
|
||||||
// Check if we have positional parameters instead of a params table
|
|
||||||
if state.GetTop() >= 3 && !state.IsTable(3) {
|
if state.GetTop() >= 3 && !state.IsTable(3) {
|
||||||
isPositional = true
|
isPositional = true
|
||||||
paramCount := state.GetTop() - 2 // Count all args after db and query
|
paramCount := state.GetTop() - 2 // Count all args after db and query
|
||||||
|
|
||||||
// Adjust connection ID index if we have positional params
|
// Check if last param is a connection ID
|
||||||
if paramCount > 0 {
|
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
|
||||||
// Last parameter might be connID if it's a string
|
if paramCount > 0 && state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
|
||||||
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
|
connID = state.ToString(lastIdx)
|
||||||
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
|
paramCount-- // Exclude connID from param count
|
||||||
connID = state.ToString(lastIdx)
|
|
||||||
paramCount-- // Exclude connID from param count
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create array for positional parameters
|
// Create array for positional parameters
|
||||||
@ -327,95 +210,170 @@ func luaSQLQuery(state *luajit.State) int {
|
|||||||
// Collect all parameters
|
// Collect all parameters
|
||||||
for i := 0; i < paramCount; i++ {
|
for i := 0; i < paramCount; i++ {
|
||||||
paramIdx := i + 3 // Params start at index 3
|
paramIdx := i + 3 // Params start at index 3
|
||||||
|
|
||||||
// Convert to appropriate Go value
|
|
||||||
var value any
|
|
||||||
switch state.GetType(paramIdx) {
|
switch state.GetType(paramIdx) {
|
||||||
case luajit.TypeNumber:
|
case luajit.TypeNumber:
|
||||||
value = state.ToNumber(paramIdx)
|
positionalParams[i] = state.ToNumber(paramIdx)
|
||||||
case luajit.TypeString:
|
case luajit.TypeString:
|
||||||
value = state.ToString(paramIdx)
|
positionalParams[i] = state.ToString(paramIdx)
|
||||||
case luajit.TypeBoolean:
|
case luajit.TypeBoolean:
|
||||||
value = state.ToBoolean(paramIdx)
|
positionalParams[i] = state.ToBoolean(paramIdx)
|
||||||
case luajit.TypeNil:
|
case luajit.TypeNil:
|
||||||
value = nil
|
positionalParams[i] = nil
|
||||||
default:
|
default:
|
||||||
// Try to convert as generic value
|
val, errConv := state.ToValue(paramIdx)
|
||||||
var err error
|
if errConv != nil {
|
||||||
value, err = state.ToValue(paramIdx)
|
return nil, "", false, nil, fmt.Errorf("failed to convert parameter %d: %w", i+1, errConv)
|
||||||
if err != nil {
|
}
|
||||||
state.PushString(fmt.Sprintf("sqlite.query: failed to convert parameter %d: %s", i+1, err.Error()))
|
positionalParams[i] = val
|
||||||
return -1
|
}
|
||||||
|
}
|
||||||
|
return nil, connID, isPositional, positionalParams, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Named parameter handling
|
||||||
|
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
||||||
|
connID = state.ToString(4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get table parameters if present
|
||||||
|
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
||||||
|
params, err = state.ToTable(3)
|
||||||
|
}
|
||||||
|
|
||||||
|
return params, connID, isPositional, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareExecOptions prepares SQLite execution options based on parameters
|
||||||
|
func prepareExecOptions(query string, params any, isPositional bool, positionalParams []any) *sqlitex.ExecOptions {
|
||||||
|
execOpts := &sqlitex.ExecOptions{}
|
||||||
|
|
||||||
|
if params == nil && !isPositional {
|
||||||
|
return execOpts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare parameters
|
||||||
|
isArray := false
|
||||||
|
var namedParams map[string]any
|
||||||
|
var arrParams []any
|
||||||
|
|
||||||
|
// Check for array parameters
|
||||||
|
if m, ok := params.(map[string]any); ok {
|
||||||
|
if arr, hasArray := m[""]; hasArray {
|
||||||
|
isArray = true
|
||||||
|
if slice, ok := arr.([]any); ok {
|
||||||
|
arrParams = slice
|
||||||
|
} else if floatSlice, ok := arr.([]float64); ok {
|
||||||
|
arrParams = make([]any, len(floatSlice))
|
||||||
|
for i, v := range floatSlice {
|
||||||
|
arrParams[i] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
positionalParams[i] = value
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Original named parameter table handling
|
|
||||||
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
|
||||||
connID = state.ToString(4)
|
|
||||||
} else {
|
} else {
|
||||||
// Generate a temporary connection ID
|
// Process named parameters
|
||||||
connID = fmt.Sprintf("temp_%p", &query)
|
namedParams = make(map[string]any, len(m))
|
||||||
|
for k, v := range m {
|
||||||
|
if len(k) > 0 && k[0] != ':' {
|
||||||
|
namedParams[":"+k] = v
|
||||||
|
} else {
|
||||||
|
namedParams[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if slice, ok := params.([]any); ok {
|
||||||
|
isArray = true
|
||||||
|
arrParams = slice
|
||||||
|
} else if floatSlice, ok := params.([]float64); ok {
|
||||||
|
isArray = true
|
||||||
|
arrParams = make([]any, len(floatSlice))
|
||||||
|
for i, v := range floatSlice {
|
||||||
|
arrParams[i] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get parameters (optional for named parameters)
|
// Use positional params if explicitly provided
|
||||||
var params any
|
if isPositional {
|
||||||
if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
arrParams = positionalParams
|
||||||
var err error
|
isArray = true
|
||||||
params, err = state.ToTable(3)
|
}
|
||||||
if err != nil {
|
|
||||||
state.PushString("sqlite.query: failed to parse parameters: " + err.Error())
|
// Limit positional params to actual placeholders
|
||||||
return -1
|
if isArray && arrParams != nil {
|
||||||
|
placeholderCount := strings.Count(query, "?")
|
||||||
|
if len(arrParams) > placeholderCount {
|
||||||
|
arrParams = arrParams[:placeholderCount]
|
||||||
}
|
}
|
||||||
|
execOpts.Args = arrParams
|
||||||
|
} else if namedParams != nil {
|
||||||
|
execOpts.Named = namedParams
|
||||||
|
}
|
||||||
|
|
||||||
|
return execOpts
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlOperation handles both query and exec operations
|
||||||
|
func sqlOperation(state *luajit.State, isQuery bool) int {
|
||||||
|
operation := "query"
|
||||||
|
if !isQuery {
|
||||||
|
operation = "exec"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get database name
|
||||||
|
if !state.IsString(1) {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.%s: database name must be a string", operation))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
dbName := state.ToString(1)
|
||||||
|
|
||||||
|
// Get query
|
||||||
|
if !state.IsString(2) {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.%s: query must be a string", operation))
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
query := state.ToString(2)
|
||||||
|
|
||||||
|
// Generate a temporary connection ID if needed
|
||||||
|
defaultConnID := fmt.Sprintf("temp_%p", &query)
|
||||||
|
|
||||||
|
// Process parameters and get connection ID
|
||||||
|
params, connID, isPositional, positionalParams, err := processParams(state, defaultConnID)
|
||||||
|
if err != nil {
|
||||||
|
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
|
||||||
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection
|
// Get connection
|
||||||
conn, pool, err := getConnection(dbName, connID)
|
conn, pool, err := getConnection(dbName, connID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.PushString("sqlite.query: " + err.Error())
|
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
// For temporary connections, defer release
|
// For temporary connections, defer release
|
||||||
if strings.HasPrefix(connID, "temp_") {
|
if strings.HasPrefix(connID, "temp_") {
|
||||||
defer func() {
|
defer func() {
|
||||||
// Release the connection
|
|
||||||
sqliteManager.mu.Lock()
|
sqliteManager.mu.Lock()
|
||||||
delete(sqliteManager.activeConns, connID)
|
delete(sqliteManager.activeConns, connID)
|
||||||
sqliteManager.mu.Unlock()
|
sqliteManager.mu.Unlock()
|
||||||
|
|
||||||
pool.Put(conn)
|
pool.Put(conn)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute query and collect results
|
// Prepare execution options
|
||||||
|
execOpts := prepareExecOptions(query, params, isPositional, positionalParams)
|
||||||
|
|
||||||
|
// Define rows slice outside the closure
|
||||||
var rows []map[string]any
|
var rows []map[string]any
|
||||||
|
|
||||||
// Prepare params based on type
|
// For queries, add result function
|
||||||
namedParams, positional := prepareParams(params)
|
if isQuery {
|
||||||
|
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||||
// If we have direct positional params from function args, use those
|
|
||||||
if isPositional {
|
|
||||||
positional = positionalParams
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count actual placeholders in the query
|
|
||||||
placeholderCount := strings.Count(query, "?")
|
|
||||||
|
|
||||||
// Execute with appropriate parameter type
|
|
||||||
execOpts := &sqlitex.ExecOptions{
|
|
||||||
ResultFunc: func(stmt *sqlite.Stmt) error {
|
|
||||||
row := make(map[string]any)
|
row := make(map[string]any)
|
||||||
columnCount := stmt.ColumnCount()
|
columnCount := stmt.ColumnCount()
|
||||||
|
|
||||||
for i := range columnCount {
|
for i := range columnCount {
|
||||||
columnName := stmt.ColumnName(i)
|
columnName := stmt.ColumnName(i)
|
||||||
columnType := stmt.ColumnType(i)
|
|
||||||
|
|
||||||
switch columnType {
|
switch stmt.ColumnType(i) {
|
||||||
case sqlite.TypeInteger:
|
case sqlite.TypeInteger:
|
||||||
row[columnName] = stmt.ColumnInt64(i)
|
row[columnName] = stmt.ColumnInt64(i)
|
||||||
case sqlite.TypeFloat:
|
case sqlite.TypeFloat:
|
||||||
@ -425,8 +383,7 @@ func luaSQLQuery(state *luajit.State) int {
|
|||||||
case sqlite.TypeBlob:
|
case sqlite.TypeBlob:
|
||||||
blobSize := stmt.ColumnLen(i)
|
blobSize := stmt.ColumnLen(i)
|
||||||
buf := make([]byte, blobSize)
|
buf := make([]byte, blobSize)
|
||||||
blob := stmt.ColumnBytes(i, buf)
|
row[columnName] = stmt.ColumnBytes(i, buf)
|
||||||
row[columnName] = blob
|
|
||||||
case sqlite.TypeNull:
|
case sqlite.TypeNull:
|
||||||
row[columnName] = nil
|
row[columnName] = nil
|
||||||
}
|
}
|
||||||
@ -437,194 +394,51 @@ func luaSQLQuery(state *luajit.State) int {
|
|||||||
maps.Copy(rowCopy, row)
|
maps.Copy(rowCopy, row)
|
||||||
rows = append(rows, rowCopy)
|
rows = append(rows, rowCopy)
|
||||||
return nil
|
return nil
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set appropriate parameter type
|
|
||||||
if namedParams != nil {
|
|
||||||
execOpts.Named = namedParams
|
|
||||||
} else if positional != nil {
|
|
||||||
// Make sure we're not passing more positional parameters than placeholders
|
|
||||||
if len(positional) > placeholderCount {
|
|
||||||
positional = positional[:placeholderCount]
|
|
||||||
}
|
}
|
||||||
execOpts.Args = positional
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = sqlitex.Execute(conn, query, execOpts)
|
// Execute query
|
||||||
|
var execErr error
|
||||||
|
if isQuery || execOpts.Args != nil || execOpts.Named != nil {
|
||||||
|
execErr = sqlitex.Execute(conn, query, execOpts)
|
||||||
|
} else {
|
||||||
|
// Use ExecScript for queries without parameters
|
||||||
|
execErr = sqlitex.ExecScript(conn, query)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if execErr != nil {
|
||||||
state.PushString("sqlite.query: " + err.Error())
|
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, execErr.Error()))
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create result table
|
// Return results for query, affected rows for exec
|
||||||
state.NewTable()
|
if isQuery {
|
||||||
|
// Create result table with rows
|
||||||
// Add results to the table
|
state.NewTable()
|
||||||
for i, row := range rows {
|
for i, row := range rows {
|
||||||
state.PushNumber(float64(i + 1))
|
state.PushNumber(float64(i + 1))
|
||||||
if err := state.PushTable(row); err != nil {
|
if err := state.PushTable(row); err != nil {
|
||||||
state.PushString("sqlite.query: " + err.Error())
|
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
|
||||||
return -1
|
return -1
|
||||||
|
}
|
||||||
|
state.SetTable(-3)
|
||||||
}
|
}
|
||||||
state.SetTable(-3)
|
} else {
|
||||||
|
// Return number of affected rows
|
||||||
|
state.PushNumber(float64(conn.Changes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// luaSQLQuery executes a SQL query and returns results to Lua
|
||||||
|
func luaSQLQuery(state *luajit.State) int {
|
||||||
|
return sqlOperation(state, true)
|
||||||
|
}
|
||||||
|
|
||||||
// luaSQLExec executes a SQL statement without returning results
|
// luaSQLExec executes a SQL statement without returning results
|
||||||
func luaSQLExec(state *luajit.State) int {
|
func luaSQLExec(state *luajit.State) int {
|
||||||
// Get database name and query
|
return sqlOperation(state, false)
|
||||||
if !state.IsString(1) {
|
|
||||||
state.PushString("sqlite.exec: database name must be a string")
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
dbName := state.ToString(1)
|
|
||||||
|
|
||||||
if !state.IsString(2) {
|
|
||||||
state.PushString("sqlite.exec: query must be a string")
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
query := state.ToString(2)
|
|
||||||
|
|
||||||
// Check if using positional parameters
|
|
||||||
isPositional := false
|
|
||||||
var positionalParams []any
|
|
||||||
|
|
||||||
// Get connection ID (optional)
|
|
||||||
var connID string
|
|
||||||
|
|
||||||
// Check if we have positional parameters instead of a params table
|
|
||||||
if state.GetTop() >= 3 && !state.IsTable(3) {
|
|
||||||
isPositional = true
|
|
||||||
paramCount := state.GetTop() - 2 // Count all args after db and query
|
|
||||||
|
|
||||||
// Adjust connection ID index if we have positional params
|
|
||||||
if paramCount > 0 {
|
|
||||||
// Last parameter might be connID if it's a string
|
|
||||||
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
|
|
||||||
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
|
|
||||||
connID = state.ToString(lastIdx)
|
|
||||||
paramCount-- // Exclude connID from param count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create array for positional parameters
|
|
||||||
positionalParams = make([]any, paramCount)
|
|
||||||
|
|
||||||
// Collect all parameters
|
|
||||||
for i := 0; i < paramCount; i++ {
|
|
||||||
paramIdx := i + 3 // Params start at index 3
|
|
||||||
|
|
||||||
// Convert to appropriate Go value
|
|
||||||
var value any
|
|
||||||
switch state.GetType(paramIdx) {
|
|
||||||
case luajit.TypeNumber:
|
|
||||||
value = state.ToNumber(paramIdx)
|
|
||||||
case luajit.TypeString:
|
|
||||||
value = state.ToString(paramIdx)
|
|
||||||
case luajit.TypeBoolean:
|
|
||||||
value = state.ToBoolean(paramIdx)
|
|
||||||
case luajit.TypeNil:
|
|
||||||
value = nil
|
|
||||||
default:
|
|
||||||
// Try to convert as generic value
|
|
||||||
var err error
|
|
||||||
value, err = state.ToValue(paramIdx)
|
|
||||||
if err != nil {
|
|
||||||
state.PushString(fmt.Sprintf("sqlite.exec: failed to convert parameter %d: %s", i+1, err.Error()))
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
positionalParams[i] = value
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Original named parameter table handling
|
|
||||||
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
|
||||||
connID = state.ToString(4)
|
|
||||||
} else {
|
|
||||||
// Generate a temporary connection ID
|
|
||||||
connID = fmt.Sprintf("temp_%p", &query)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get parameters (optional for named parameters)
|
|
||||||
var params any
|
|
||||||
if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
|
||||||
var err error
|
|
||||||
params, err = state.ToTable(3)
|
|
||||||
if err != nil {
|
|
||||||
state.PushString("sqlite.exec: failed to parse parameters: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get connection
|
|
||||||
conn, pool, err := getConnection(dbName, connID)
|
|
||||||
if err != nil {
|
|
||||||
state.PushString("sqlite.exec: " + err.Error())
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// For temporary connections, defer release
|
|
||||||
if strings.HasPrefix(connID, "temp_") {
|
|
||||||
defer func() {
|
|
||||||
// Release the connection
|
|
||||||
sqliteManager.mu.Lock()
|
|
||||||
delete(sqliteManager.activeConns, connID)
|
|
||||||
sqliteManager.mu.Unlock()
|
|
||||||
|
|
||||||
pool.Put(conn)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count actual placeholders in the query
|
|
||||||
placeholderCount := strings.Count(query, "?")
|
|
||||||
|
|
||||||
// Prepare params based on type
|
|
||||||
namedParams, positional := prepareParams(params)
|
|
||||||
|
|
||||||
// If we have direct positional params from function args, use those
|
|
||||||
if isPositional {
|
|
||||||
positional = positionalParams
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure we don't pass more parameters than placeholders
|
|
||||||
if positional != nil && len(positional) > placeholderCount {
|
|
||||||
positional = positional[:placeholderCount]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute with appropriate parameter type
|
|
||||||
var execErr error
|
|
||||||
|
|
||||||
if isPositional || positional != nil {
|
|
||||||
// Execute with positional parameters
|
|
||||||
execOpts := &sqlitex.ExecOptions{
|
|
||||||
Args: positional,
|
|
||||||
}
|
|
||||||
execErr = sqlitex.Execute(conn, query, execOpts)
|
|
||||||
} else if namedParams != nil {
|
|
||||||
// Execute with named parameters
|
|
||||||
execOpts := &sqlitex.ExecOptions{
|
|
||||||
Named: namedParams,
|
|
||||||
}
|
|
||||||
execErr = sqlitex.Execute(conn, query, execOpts)
|
|
||||||
} else {
|
|
||||||
// Execute without parameters
|
|
||||||
execErr = sqlitex.ExecScript(conn, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if execErr != nil {
|
|
||||||
state.PushString("sqlite.exec: " + execErr.Error())
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return number of affected rows
|
|
||||||
state.PushNumber(float64(conn.Changes()))
|
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
|
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
|
||||||
@ -632,10 +446,5 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
|
|||||||
if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil {
|
if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return state.RegisterGoFunction("__sqlite_exec", luaSQLExec)
|
||||||
if err := state.RegisterGoFunction("__sqlite_exec", luaSQLExec); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user