Refactor sqlite, up password mem default, refactor fs

This commit is contained in:
Sky Johnson 2025-05-03 16:18:18 -05:00
parent 551f311755
commit 30a126909b
4 changed files with 191 additions and 383 deletions

View File

@ -1,29 +1,28 @@
local fs = {} local fs = {}
-- File Operations fs.read = function(path)
fs.read_file = function(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.read_file: path must be a string", 2) error("fs.read: path must be a string", 2)
end end
return __fs_read_file(path) return __fs_read_file(path)
end end
fs.write_file = function(path, content) fs.write = function(path, content)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.write_file: path must be a string", 2) error("fs.write: path must be a string", 2)
end end
if type(content) ~= "string" then if type(content) ~= "string" then
error("fs.write_file: content must be a string", 2) error("fs.write: content must be a string", 2)
end end
return __fs_write_file(path, content) return __fs_write_file(path, content)
end end
fs.append_file = function(path, content) fs.append = function(path, content)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.append_file: path must be a string", 2) error("fs.append: path must be a string", 2)
end end
if type(content) ~= "string" then if type(content) ~= "string" then
error("fs.append_file: content must be a string", 2) error("fs.append: content must be a string", 2)
end end
return __fs_append_file(path, content) return __fs_append_file(path, content)
end end
@ -35,16 +34,16 @@ fs.exists = function(path)
return __fs_exists(path) return __fs_exists(path)
end end
fs.remove_file = function(path) fs.remove = function(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.remove_file: path must be a string", 2) error("fs.remove: path must be a string", 2)
end end
return __fs_remove_file(path) return __fs_remove_file(path)
end end
fs.get_info = function(path) fs.info = function(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.get_info: path must be a string", 2) error("fs.info: path must be a string", 2)
end end
local info = __fs_get_info(path) local info = __fs_get_info(path)
@ -57,24 +56,24 @@ fs.get_info = function(path)
end end
-- Directory Operations -- Directory Operations
fs.make_dir = function(path, mode) fs.mkdir = function(path, mode)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.make_dir: path must be a string", 2) error("fs.mkdir: path must be a string", 2)
end end
mode = mode or 0755 mode = mode or 0755
return __fs_make_dir(path, mode) return __fs_make_dir(path, mode)
end end
fs.list_dir = function(path) fs.ls = function(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.list_dir: path must be a string", 2) error("fs.ls: path must be a string", 2)
end end
return __fs_list_dir(path) return __fs_list_dir(path)
end end
fs.remove_dir = function(path, recursive) fs.rmdir = function(path, recursive)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.remove_dir: path must be a string", 2) error("fs.rmdir: path must be a string", 2)
end end
recursive = recursive or false recursive = recursive or false
return __fs_remove_dir(path, recursive) return __fs_remove_dir(path, recursive)

View File

@ -28,7 +28,7 @@ func passwordHash(state *luajit.State) int {
password := state.ToString(1) password := state.ToString(1)
params := &argon2id.Params{ params := &argon2id.Params{
Memory: 64 * 1024, Memory: 128 * 1024,
Iterations: 4, Iterations: 4,
Parallelism: 4, Parallelism: 4,
SaltLength: 16, SaltLength: 16,

View File

@ -585,8 +585,8 @@ local password = {}
-- Hash a password using Argon2id -- Hash a password using Argon2id
-- Options: -- Options:
-- memory: Amount of memory to use in KB (default: 64MB) -- memory: Amount of memory to use in KB (default: 128MB)
-- iterations: Number of iterations (default: 3) -- iterations: Number of iterations (default: 4)
-- parallelism: Number of threads (default: 4) -- parallelism: Number of threads (default: 4)
-- salt_length: Length of salt in bytes (default: 16) -- salt_length: Length of salt in bytes (default: 16)
-- key_length: Length of the derived key in bytes (default: 32) -- key_length: Length of the derived key in bytes (default: 32)

View File

@ -33,7 +33,6 @@ type SQLiteManager struct {
dataDir string dataDir string
} }
// Global manager
var sqliteManager *SQLiteManager var sqliteManager *SQLiteManager
// InitSQLite initializes the SQLite manager // InitSQLite initializes the SQLite manager
@ -55,7 +54,7 @@ func CleanupSQLite() {
sqliteManager.mu.Lock() sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock() defer sqliteManager.mu.Unlock()
// Release all active connections // Release all connections and close pools
for id, conn := range sqliteManager.activeConns { for id, conn := range sqliteManager.activeConns {
if conn.Pool != nil { if conn.Pool != nil {
conn.Pool.Put(conn.Conn) conn.Pool.Put(conn.Conn)
@ -63,7 +62,6 @@ func CleanupSQLite() {
delete(sqliteManager.activeConns, id) delete(sqliteManager.activeConns, id)
} }
// Close all pools
for name, pool := range sqliteManager.pools { for name, pool := range sqliteManager.pools {
if err := pool.Close(); err != nil { if err := pool.Close(); err != nil {
logger.Error("Failed to close database %s: %v", name, err) logger.Error("Failed to close database %s: %v", name, err)
@ -94,13 +92,10 @@ func ReleaseActiveConnections(state *luajit.State) {
// Iterate through active connections // Iterate through active connections
state.PushNil() // Start iteration state.PushNil() // Start iteration
for state.Next(-2) { for state.Next(-2) {
// Stack now has key at -2 and value at -1
if state.IsTable(-1) { if state.IsTable(-1) {
state.GetField(-1, "id") state.GetField(-1, "id")
if state.IsString(-1) { if state.IsString(-1) {
connID := state.ToString(-1) connID := state.ToString(-1)
// Release connection from Go side
if conn, exists := sqliteManager.activeConns[connID]; exists { if conn, exists := sqliteManager.activeConns[connID]; exists {
if conn.Pool != nil { if conn.Pool != nil {
conn.Pool.Put(conn.Conn) conn.Pool.Put(conn.Conn)
@ -130,16 +125,15 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return nil, errors.New("invalid database name") return nil, errors.New("invalid database name")
} }
// Check for existing pool // Check for existing pool with read lock
sqliteManager.mu.RLock() sqliteManager.mu.RLock()
pool, exists := sqliteManager.pools[dbName] pool, exists := sqliteManager.pools[dbName]
sqliteManager.mu.RUnlock() sqliteManager.mu.RUnlock()
if exists { if exists {
return pool, nil return pool, nil
} }
// Create new pool // Create new pool with write lock
sqliteManager.mu.Lock() sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock() defer sqliteManager.mu.Unlock()
@ -148,12 +142,9 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return pool, nil return pool, nil
} }
// Create database file path // Create database file path and pool
dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db") dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
// Create the pool
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
@ -179,7 +170,7 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
return nil, nil, err return nil, nil, err
} }
// Get a connection using the newer Take API // Get a connection
dbConn, err := pool.Take(context.Background()) dbConn, err := pool.Take(context.Background())
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err) return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
@ -197,129 +188,21 @@ func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, e
return dbConn, pool, nil return dbConn, pool, nil
} }
// detectParamType determines if parameters are positional or named // processParams extracts parameters and connection ID from Lua state
func detectParamType(params any) (isArray bool) { func processParams(state *luajit.State, defaultConnID string) (params any, connID string, isPositional bool, positionalParams []any, err error) {
if params == nil { connID = defaultConnID
return false
}
// Check if it's a map[string]any
if paramsMap, ok := params.(map[string]any); ok {
// Check for the empty string key which indicates an array
if array, hasArray := paramsMap[""]; hasArray {
// Verify it's actually an array
if _, isSlice := array.([]any); isSlice {
return true
}
if _, isFloatSlice := array.([]float64); isFloatSlice {
return true
}
}
return false
}
// If it's already a slice type
if _, ok := params.([]any); ok {
return true
}
if _, ok := params.([]float64); ok {
return true
}
return false
}
// prepareParams processes parameters for SQLite queries
func prepareParams(params any) (map[string]any, []any) {
if params == nil {
return nil, nil
}
// Handle positional parameters (array-like)
if detectParamType(params) {
var positional []any
// Extract array from special map format
if paramsMap, ok := params.(map[string]any); ok {
if array, hasArray := paramsMap[""]; hasArray {
if slice, ok := array.([]any); ok {
positional = slice
} else if floatSlice, ok := array.([]float64); ok {
// Convert []float64 to []any
positional = make([]any, len(floatSlice))
for i, v := range floatSlice {
positional[i] = v
}
}
}
} else if slice, ok := params.([]any); ok {
positional = slice
} else if floatSlice, ok := params.([]float64); ok {
// Convert []float64 to []any
positional = make([]any, len(floatSlice))
for i, v := range floatSlice {
positional[i] = v
}
}
return nil, positional
}
// Handle named parameters (map-like)
if paramsMap, ok := params.(map[string]any); ok {
modified := make(map[string]any, len(paramsMap))
for key, value := range paramsMap {
if len(key) > 0 && key[0] != ':' {
modified[":"+key] = value
} else {
modified[key] = value
}
}
return modified, nil
}
return nil, nil
}
// luaSQLQuery executes a SQL query and returns results to Lua
func luaSQLQuery(state *luajit.State) int {
// Get database name
if !state.IsString(1) {
state.PushString("sqlite.query: database name must be a string")
return -1
}
dbName := state.ToString(1)
// Get query
if !state.IsString(2) {
state.PushString("sqlite.query: query must be a string")
return -1
}
query := state.ToString(2)
// Check if using positional parameters // Check if using positional parameters
isPositional := false
var positionalParams []any
// Get connection ID (optional)
var connID string
// Check if we have positional parameters instead of a params table
if state.GetTop() >= 3 && !state.IsTable(3) { if state.GetTop() >= 3 && !state.IsTable(3) {
isPositional = true isPositional = true
paramCount := state.GetTop() - 2 // Count all args after db and query paramCount := state.GetTop() - 2 // Count all args after db and query
// Adjust connection ID index if we have positional params // Check if last param is a connection ID
if paramCount > 0 {
// Last parameter might be connID if it's a string
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) { if paramCount > 0 && state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
connID = state.ToString(lastIdx) connID = state.ToString(lastIdx)
paramCount-- // Exclude connID from param count paramCount-- // Exclude connID from param count
} }
}
// Create array for positional parameters // Create array for positional parameters
positionalParams = make([]any, paramCount) positionalParams = make([]any, paramCount)
@ -327,95 +210,170 @@ func luaSQLQuery(state *luajit.State) int {
// Collect all parameters // Collect all parameters
for i := 0; i < paramCount; i++ { for i := 0; i < paramCount; i++ {
paramIdx := i + 3 // Params start at index 3 paramIdx := i + 3 // Params start at index 3
// Convert to appropriate Go value
var value any
switch state.GetType(paramIdx) { switch state.GetType(paramIdx) {
case luajit.TypeNumber: case luajit.TypeNumber:
value = state.ToNumber(paramIdx) positionalParams[i] = state.ToNumber(paramIdx)
case luajit.TypeString: case luajit.TypeString:
value = state.ToString(paramIdx) positionalParams[i] = state.ToString(paramIdx)
case luajit.TypeBoolean: case luajit.TypeBoolean:
value = state.ToBoolean(paramIdx) positionalParams[i] = state.ToBoolean(paramIdx)
case luajit.TypeNil: case luajit.TypeNil:
value = nil positionalParams[i] = nil
default: default:
// Try to convert as generic value val, errConv := state.ToValue(paramIdx)
var err error if errConv != nil {
value, err = state.ToValue(paramIdx) return nil, "", false, nil, fmt.Errorf("failed to convert parameter %d: %w", i+1, errConv)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: failed to convert parameter %d: %s", i+1, err.Error()))
return -1
} }
positionalParams[i] = val
}
}
return nil, connID, isPositional, positionalParams, nil
} }
positionalParams[i] = value // Named parameter handling
}
} else {
// Original named parameter table handling
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4) connID = state.ToString(4)
}
// Get table parameters if present
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
params, err = state.ToTable(3)
}
return params, connID, isPositional, nil, err
}
// prepareExecOptions prepares SQLite execution options based on parameters
func prepareExecOptions(query string, params any, isPositional bool, positionalParams []any) *sqlitex.ExecOptions {
execOpts := &sqlitex.ExecOptions{}
if params == nil && !isPositional {
return execOpts
}
// Prepare parameters
isArray := false
var namedParams map[string]any
var arrParams []any
// Check for array parameters
if m, ok := params.(map[string]any); ok {
if arr, hasArray := m[""]; hasArray {
isArray = true
if slice, ok := arr.([]any); ok {
arrParams = slice
} else if floatSlice, ok := arr.([]float64); ok {
arrParams = make([]any, len(floatSlice))
for i, v := range floatSlice {
arrParams[i] = v
}
}
} else { } else {
// Generate a temporary connection ID // Process named parameters
connID = fmt.Sprintf("temp_%p", &query) namedParams = make(map[string]any, len(m))
for k, v := range m {
if len(k) > 0 && k[0] != ':' {
namedParams[":"+k] = v
} else {
namedParams[k] = v
}
}
}
} else if slice, ok := params.([]any); ok {
isArray = true
arrParams = slice
} else if floatSlice, ok := params.([]float64); ok {
isArray = true
arrParams = make([]any, len(floatSlice))
for i, v := range floatSlice {
arrParams[i] = v
} }
} }
// Get parameters (optional for named parameters) // Use positional params if explicitly provided
var params any if isPositional {
if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { arrParams = positionalParams
var err error isArray = true
params, err = state.ToTable(3) }
if err != nil {
state.PushString("sqlite.query: failed to parse parameters: " + err.Error()) // Limit positional params to actual placeholders
if isArray && arrParams != nil {
placeholderCount := strings.Count(query, "?")
if len(arrParams) > placeholderCount {
arrParams = arrParams[:placeholderCount]
}
execOpts.Args = arrParams
} else if namedParams != nil {
execOpts.Named = namedParams
}
return execOpts
}
// sqlOperation handles both query and exec operations
func sqlOperation(state *luajit.State, isQuery bool) int {
operation := "query"
if !isQuery {
operation = "exec"
}
// Get database name
if !state.IsString(1) {
state.PushString(fmt.Sprintf("sqlite.%s: database name must be a string", operation))
return -1 return -1
} }
dbName := state.ToString(1)
// Get query
if !state.IsString(2) {
state.PushString(fmt.Sprintf("sqlite.%s: query must be a string", operation))
return -1
}
query := state.ToString(2)
// Generate a temporary connection ID if needed
defaultConnID := fmt.Sprintf("temp_%p", &query)
// Process parameters and get connection ID
params, connID, isPositional, positionalParams, err := processParams(state, defaultConnID)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
return -1
} }
// Get connection // Get connection
conn, pool, err := getConnection(dbName, connID) conn, pool, err := getConnection(dbName, connID)
if err != nil { if err != nil {
state.PushString("sqlite.query: " + err.Error()) state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
return -1 return -1
} }
// For temporary connections, defer release // For temporary connections, defer release
if strings.HasPrefix(connID, "temp_") { if strings.HasPrefix(connID, "temp_") {
defer func() { defer func() {
// Release the connection
sqliteManager.mu.Lock() sqliteManager.mu.Lock()
delete(sqliteManager.activeConns, connID) delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock() sqliteManager.mu.Unlock()
pool.Put(conn) pool.Put(conn)
}() }()
} }
// Execute query and collect results // Prepare execution options
execOpts := prepareExecOptions(query, params, isPositional, positionalParams)
// Define rows slice outside the closure
var rows []map[string]any var rows []map[string]any
// Prepare params based on type // For queries, add result function
namedParams, positional := prepareParams(params) if isQuery {
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
// If we have direct positional params from function args, use those
if isPositional {
positional = positionalParams
}
// Count actual placeholders in the query
placeholderCount := strings.Count(query, "?")
// Execute with appropriate parameter type
execOpts := &sqlitex.ExecOptions{
ResultFunc: func(stmt *sqlite.Stmt) error {
row := make(map[string]any) row := make(map[string]any)
columnCount := stmt.ColumnCount() columnCount := stmt.ColumnCount()
for i := range columnCount { for i := range columnCount {
columnName := stmt.ColumnName(i) columnName := stmt.ColumnName(i)
columnType := stmt.ColumnType(i)
switch columnType { switch stmt.ColumnType(i) {
case sqlite.TypeInteger: case sqlite.TypeInteger:
row[columnName] = stmt.ColumnInt64(i) row[columnName] = stmt.ColumnInt64(i)
case sqlite.TypeFloat: case sqlite.TypeFloat:
@ -425,8 +383,7 @@ func luaSQLQuery(state *luajit.State) int {
case sqlite.TypeBlob: case sqlite.TypeBlob:
blobSize := stmt.ColumnLen(i) blobSize := stmt.ColumnLen(i)
buf := make([]byte, blobSize) buf := make([]byte, blobSize)
blob := stmt.ColumnBytes(i, buf) row[columnName] = stmt.ColumnBytes(i, buf)
row[columnName] = blob
case sqlite.TypeNull: case sqlite.TypeNull:
row[columnName] = nil row[columnName] = nil
} }
@ -437,194 +394,51 @@ func luaSQLQuery(state *luajit.State) int {
maps.Copy(rowCopy, row) maps.Copy(rowCopy, row)
rows = append(rows, rowCopy) rows = append(rows, rowCopy)
return nil return nil
}, }
} }
// Set appropriate parameter type // Execute query
if namedParams != nil { var execErr error
execOpts.Named = namedParams if isQuery || execOpts.Args != nil || execOpts.Named != nil {
} else if positional != nil { execErr = sqlitex.Execute(conn, query, execOpts)
// Make sure we're not passing more positional parameters than placeholders } else {
if len(positional) > placeholderCount { // Use ExecScript for queries without parameters
positional = positional[:placeholderCount] execErr = sqlitex.ExecScript(conn, query)
}
execOpts.Args = positional
} }
err = sqlitex.Execute(conn, query, execOpts) if execErr != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, execErr.Error()))
if err != nil {
state.PushString("sqlite.query: " + err.Error())
return -1 return -1
} }
// Create result table // Return results for query, affected rows for exec
if isQuery {
// Create result table with rows
state.NewTable() state.NewTable()
// Add results to the table
for i, row := range rows { for i, row := range rows {
state.PushNumber(float64(i + 1)) state.PushNumber(float64(i + 1))
if err := state.PushTable(row); err != nil { if err := state.PushTable(row); err != nil {
state.PushString("sqlite.query: " + err.Error()) state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
return -1 return -1
} }
state.SetTable(-3) state.SetTable(-3)
} }
} else {
// Return number of affected rows
state.PushNumber(float64(conn.Changes()))
}
return 1 return 1
} }
// luaSQLQuery executes a SQL query and returns results to Lua
func luaSQLQuery(state *luajit.State) int {
return sqlOperation(state, true)
}
// luaSQLExec executes a SQL statement without returning results // luaSQLExec executes a SQL statement without returning results
func luaSQLExec(state *luajit.State) int { func luaSQLExec(state *luajit.State) int {
// Get database name and query return sqlOperation(state, false)
if !state.IsString(1) {
state.PushString("sqlite.exec: database name must be a string")
return -1
}
dbName := state.ToString(1)
if !state.IsString(2) {
state.PushString("sqlite.exec: query must be a string")
return -1
}
query := state.ToString(2)
// Check if using positional parameters
isPositional := false
var positionalParams []any
// Get connection ID (optional)
var connID string
// Check if we have positional parameters instead of a params table
if state.GetTop() >= 3 && !state.IsTable(3) {
isPositional = true
paramCount := state.GetTop() - 2 // Count all args after db and query
// Adjust connection ID index if we have positional params
if paramCount > 0 {
// Last parameter might be connID if it's a string
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
if state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
connID = state.ToString(lastIdx)
paramCount-- // Exclude connID from param count
}
}
// Create array for positional parameters
positionalParams = make([]any, paramCount)
// Collect all parameters
for i := 0; i < paramCount; i++ {
paramIdx := i + 3 // Params start at index 3
// Convert to appropriate Go value
var value any
switch state.GetType(paramIdx) {
case luajit.TypeNumber:
value = state.ToNumber(paramIdx)
case luajit.TypeString:
value = state.ToString(paramIdx)
case luajit.TypeBoolean:
value = state.ToBoolean(paramIdx)
case luajit.TypeNil:
value = nil
default:
// Try to convert as generic value
var err error
value, err = state.ToValue(paramIdx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: failed to convert parameter %d: %s", i+1, err.Error()))
return -1
}
}
positionalParams[i] = value
}
} else {
// Original named parameter table handling
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4)
} else {
// Generate a temporary connection ID
connID = fmt.Sprintf("temp_%p", &query)
}
}
// Get parameters (optional for named parameters)
var params any
if !isPositional && state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
var err error
params, err = state.ToTable(3)
if err != nil {
state.PushString("sqlite.exec: failed to parse parameters: " + err.Error())
return -1
}
}
// Get connection
conn, pool, err := getConnection(dbName, connID)
if err != nil {
state.PushString("sqlite.exec: " + err.Error())
return -1
}
// For temporary connections, defer release
if strings.HasPrefix(connID, "temp_") {
defer func() {
// Release the connection
sqliteManager.mu.Lock()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
pool.Put(conn)
}()
}
// Count actual placeholders in the query
placeholderCount := strings.Count(query, "?")
// Prepare params based on type
namedParams, positional := prepareParams(params)
// If we have direct positional params from function args, use those
if isPositional {
positional = positionalParams
}
// Ensure we don't pass more parameters than placeholders
if positional != nil && len(positional) > placeholderCount {
positional = positional[:placeholderCount]
}
// Execute with appropriate parameter type
var execErr error
if isPositional || positional != nil {
// Execute with positional parameters
execOpts := &sqlitex.ExecOptions{
Args: positional,
}
execErr = sqlitex.Execute(conn, query, execOpts)
} else if namedParams != nil {
// Execute with named parameters
execOpts := &sqlitex.ExecOptions{
Named: namedParams,
}
execErr = sqlitex.Execute(conn, query, execOpts)
} else {
// Execute without parameters
execErr = sqlitex.ExecScript(conn, query)
}
if execErr != nil {
state.PushString("sqlite.exec: " + execErr.Error())
return -1
}
// Return number of affected rows
state.PushNumber(float64(conn.Changes()))
return 1
} }
// RegisterSQLiteFunctions registers SQLite functions with the Lua state // RegisterSQLiteFunctions registers SQLite functions with the Lua state
@ -632,10 +446,5 @@ func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil { if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil {
return err return err
} }
return state.RegisterGoFunction("__sqlite_exec", luaSQLExec)
if err := state.RegisterGoFunction("__sqlite_exec", luaSQLExec); err != nil {
return err
}
return nil
} }