sqlite improvements

This commit is contained in:
Sky Johnson 2025-05-10 14:53:37 -05:00
parent c754877f7d
commit 98b2931d59
2 changed files with 504 additions and 366 deletions

View File

@ -27,20 +27,37 @@ local connection_mt = {
error("connection:query: query must be a string", 2)
end
-- Handle params (named or positional)
local processed_params = handle_params(params, ...)
-- Fast path for no parameters
if params == nil and select('#', ...) == 0 then
return __sqlite_query(self.db_name, query, nil, self.id)
end
-- Call with appropriate arguments
if type(processed_params) == "table" and processed_params[1] ~= nil then
-- Positional parameters - insert self.db_name and query at the beginning
table.insert(processed_params, 1, query)
table.insert(processed_params, 1, self.db_name)
-- Add connection ID at the end
table.insert(processed_params, self.id)
return __sqlite_query(unpack(processed_params))
-- Handle various parameter types efficiently
if type(params) == "table" then
-- If it's an array-like table with numeric keys
if params[1] ~= nil then
-- For positional parameters, we want to include the required prefix args
local args = {self.db_name, query}
-- Append all parameters
for i=1, #params do
args[i+2] = params[i]
end
-- Add connection ID
args[#args+1] = self.id
return __sqlite_query(unpack(args))
else
-- Named parameters or no parameters
return __sqlite_query(self.db_name, query, processed_params, self.id)
-- Named parameters
return __sqlite_query(self.db_name, query, params, self.id)
end
else
-- Variadic parameters, combine with first param
local args = {self.db_name, query, params}
local n = select('#', ...)
for i=1, n do
args[i+3] = select(i, ...)
end
args[#args+1] = self.id
return __sqlite_query(unpack(args))
end
end,
@ -50,20 +67,37 @@ local connection_mt = {
error("connection:exec: query must be a string", 2)
end
-- Handle params (named or positional)
local processed_params = handle_params(params, ...)
-- Fast path for no parameters
if params == nil and select('#', ...) == 0 then
return __sqlite_exec(self.db_name, query, nil, self.id)
end
-- Call with appropriate arguments
if type(processed_params) == "table" and processed_params[1] ~= nil then
-- Positional parameters - insert self.db_name and query at the beginning
table.insert(processed_params, 1, query)
table.insert(processed_params, 1, self.db_name)
-- Add connection ID at the end
table.insert(processed_params, self.id)
return __sqlite_exec(unpack(processed_params))
-- Handle various parameter types efficiently
if type(params) == "table" then
-- If it's an array-like table with numeric keys
if params[1] ~= nil then
-- For positional parameters, we want to include the required prefix args
local args = {self.db_name, query}
-- Append all parameters
for i=1, #params do
args[i+2] = params[i]
end
-- Add connection ID
args[#args+1] = self.id
return __sqlite_exec(unpack(args))
else
-- Named parameters or no parameters
return __sqlite_exec(self.db_name, query, processed_params, self.id)
-- Named parameters
return __sqlite_exec(self.db_name, query, params, self.id)
end
else
-- Variadic parameters, combine with first param
local args = {self.db_name, query, params}
local n = select('#', ...)
for i=1, n do
args[i+3] = select(i, ...)
end
args[#args+1] = self.id
return __sqlite_exec(unpack(args))
end
end,
@ -79,7 +113,7 @@ local connection_mt = {
local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)")
if index_def then
-- Parse index definition: INDEX:idx_name(col1,col2)
-- Parse index definition
local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)")
if index_name and columns_str then
@ -106,34 +140,32 @@ local connection_mt = {
error("connection:create_table: no columns specified", 2)
end
-- Create the table
local query = string.format("CREATE TABLE IF NOT EXISTS %s (%s)",
table_name, table.concat(columns, ", "))
-- Build combined statement for table and indices
local statements = {}
local result = self:exec(query)
-- Create indices
if #indices > 0 then
self:begin()
-- Add the CREATE TABLE statement
table.insert(statements, string.format(
"CREATE TABLE IF NOT EXISTS %s (%s)",
table_name,
table.concat(columns, ", ")
))
-- Add CREATE INDEX statements
for _, idx in ipairs(indices) do
local unique = idx.unique and "UNIQUE " or ""
local index_query = string.format(
table.insert(statements, string.format(
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
unique,
idx.name,
table_name,
table.concat(idx.columns, ", ")
)
self:exec(index_query)
))
end
self:commit()
end
return result
-- Execute all statements in a single transaction
local combined_sql = table.concat(statements, ";\n")
return self:exec(combined_sql)
end,
-- Insert a row or multiple rows
@ -142,10 +174,44 @@ local connection_mt = {
error("connection:insert: data must be a table", 2)
end
-- Case 1: Named columns with array data
if columns and type(columns) == "table" then
-- Check if we have multiple rows
if #data > 0 and type(data[1]) == "table" then
-- Build a single multi-value INSERT
local placeholders = {}
for _ in ipairs(columns) do
table.insert(placeholders, "?")
local values = {}
local params = {}
local param_index = 1
for i, row in ipairs(data) do
local row_placeholders = {}
for j, _ in ipairs(columns) do
local param_name = "p" .. param_index
table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[j]
param_index = param_index + 1
end
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
end
local query = string.format(
"INSERT INTO %s (%s) VALUES %s",
table_name,
table.concat(columns, ", "),
table.concat(placeholders, ", ")
)
return self:exec(query, params)
else
-- Single row with defined columns
local placeholders = {}
local params = {}
for i, col in ipairs(columns) do
local param_name = "p" .. i
table.insert(placeholders, ":" .. param_name)
params[param_name] = data[i]
end
local query = string.format(
@ -155,30 +221,11 @@ local connection_mt = {
table.concat(placeholders, ", ")
)
local use_transaction = #data > 1 and type(data[1]) == "table"
if use_transaction then
self:begin()
end
local affected = 0
if #data > 0 and type(data[1]) == "table" then
for _, row in ipairs(data) do
local result = self:exec(query, row)
affected = affected + result
end
else
affected = self:exec(query, data)
end
if use_transaction then
self:commit()
end
return affected
return self:exec(query, params)
end
end
-- Case 2: Object-style single row {col1=val1, col2=val2}
if data[1] == nil and next(data) ~= nil then
local columns = {}
local placeholders = {}
@ -186,8 +233,9 @@ local connection_mt = {
for col, val in pairs(data) do
table.insert(columns, col)
table.insert(placeholders, ":" .. col)
params[":" .. col] = val
local param_name = "p" .. #columns
table.insert(placeholders, ":" .. param_name)
params[param_name] = val
end
local query = string.format(
@ -200,34 +248,74 @@ local connection_mt = {
return self:exec(query, params)
end
-- Case 3: Array of rows without predefined columns
if #data > 0 and type(data[1]) == "table" then
self:begin()
local affected = 0
-- Extract columns from the first row
local first_row = data[1]
local inferred_columns = {}
for _, row in ipairs(data) do
local result = self:insert(table_name, row)
affected = affected + result
-- Determine if first row is array or object
local is_array = first_row[1] ~= nil
if is_array then
-- Cannot infer column names from array
error("connection:insert: column names required for array data", 2)
else
-- Get columns from object keys
for col, _ in pairs(first_row) do
table.insert(inferred_columns, col)
end
self:commit()
return affected
-- Build multi-value INSERT
local placeholders = {}
local params = {}
local param_index = 1
for _, row in ipairs(data) do
local row_placeholders = {}
for _, col in ipairs(inferred_columns) do
local param_name = "p" .. param_index
table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[col]
param_index = param_index + 1
end
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
end
local query = string.format(
"INSERT INTO %s (%s) VALUES %s",
table_name,
table.concat(inferred_columns, ", "),
table.concat(placeholders, ", ")
)
return self:exec(query, params)
end
end
error("connection:insert: invalid data format", 2)
end,
-- Update rows
update = function(self, table_name, data, where, where_params)
update = function(self, table_name, data, where, where_params, ...)
if type(data) ~= "table" then
error("connection:update: data must be a table", 2)
end
-- Fast path for when there's no data
if next(data) == nil then
return 0
end
local sets = {}
local params = {}
local param_index = 1
for col, val in pairs(data) do
table.insert(sets, col .. " = :" .. col)
params[col] = val
local param_name = "p" .. param_index
table.insert(sets, col .. " = :" .. param_name)
params[param_name] = val
param_index = param_index + 1
end
local query = string.format(
@ -240,8 +328,53 @@ local connection_mt = {
query = query .. " WHERE " .. where
if where_params then
if type(where_params) == "table" then
-- Handle named parameters in WHERE clause
for k, v in pairs(where_params) do
params[k] = v
local param_name
if type(k) == "string" and k:sub(1, 1) == ":" then
param_name = k:sub(2)
else
param_name = "w" .. param_index
-- Replace the placeholder in the WHERE clause
where = where:gsub(":" .. k, ":" .. param_name)
end
params[param_name] = v
param_index = param_index + 1
end
else
-- Handle positional parameters (? placeholders)
local args = {where_params, ...}
local pos = 1
local offset = 0
-- Replace ? with named parameters
while true do
local start_pos, end_pos = where:find("?", pos)
if not start_pos then break end
local param_name = "w" .. param_index
local replacement = ":" .. param_name
where = where:sub(1, start_pos - 1) .. replacement .. where:sub(end_pos + 1)
if args[pos - offset] ~= nil then
params[param_name] = args[pos - offset]
else
params[param_name] = nil
end
param_index = param_index + 1
pos = start_pos + #replacement
offset = offset + 1
end
query = string.format(
"UPDATE %s SET %s WHERE %s",
table_name,
table.concat(sets, ", "),
where
)
end
end
end
@ -260,15 +393,25 @@ local connection_mt = {
return self:exec(query, params)
end,
-- Get one row
-- Get one row efficiently
get_one = function(self, query, params, ...)
-- Handle both named and positional parameters
if type(query) ~= "string" then
error("connection:get_one: query must be a string", 2)
end
-- Add LIMIT 1 to query if not already limited
local limited_query = query
if not query:lower():match("limit%s+%d+") then
limited_query = query .. " LIMIT 1"
end
local results
if select('#', ...) > 0 then
results = self:query(query, params, ...)
results = self:query(limited_query, params, ...)
else
results = self:query(query, params)
results = self:query(limited_query, params)
end
return results[1]
end,

View File

@ -13,14 +13,11 @@ import (
"Moonshark/utils/logger"
"maps"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
)
// SQLiteConnection tracks an active connection
type SQLiteConnection struct {
DbName string
Conn *sqlite.Conn
Pool *sqlitex.Pool
}
@ -54,7 +51,6 @@ func CleanupSQLite() {
sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock()
// Release all connections and close pools
for id, conn := range sqliteManager.activeConns {
if conn.Pool != nil {
conn.Pool.Put(conn.Conn)
@ -79,9 +75,6 @@ func ReleaseActiveConnections(state *luajit.State) {
return
}
sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock()
// Get active connections table from Lua
state.GetGlobal("__active_sqlite_connections")
if !state.IsTable(-1) {
@ -89,6 +82,9 @@ func ReleaseActiveConnections(state *luajit.State) {
return
}
sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock()
// Iterate through active connections
state.PushNil() // Start iteration
for state.Next(-2) {
@ -113,8 +109,8 @@ func ReleaseActiveConnections(state *luajit.State) {
state.SetGlobal("__active_sqlite_connections")
}
// getPool returns a connection pool for the specified database
func getPool(dbName string) (*sqlitex.Pool, error) {
// getConnection returns a connection for the database
func getConnection(dbName, connID string) (*sqlite.Conn, error) {
if sqliteManager == nil {
return nil, errors.New("SQLite not initialized")
}
@ -125,326 +121,325 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return nil, errors.New("invalid database name")
}
// Check for existing pool with read lock
// Check for existing connection
sqliteManager.mu.RLock()
pool, exists := sqliteManager.pools[dbName]
sqliteManager.mu.RUnlock()
conn, exists := sqliteManager.activeConns[connID]
if exists {
return pool, nil
sqliteManager.mu.RUnlock()
return conn.Conn, nil
}
sqliteManager.mu.RUnlock()
// Create new pool with write lock
// Get or create pool under write lock
sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock()
// Double check if another goroutine created it
if pool, exists = sqliteManager.pools[dbName]; exists {
return pool, nil
// Double-check if a connection was created while waiting for lock
if conn, exists = sqliteManager.activeConns[connID]; exists {
return conn.Conn, nil
}
// Create database file path and pool
// Get or create pool
pool, exists := sqliteManager.pools[dbName]
if !exists {
dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
var err error
pool, err = sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
sqliteManager.pools[dbName] = pool
return pool, nil
}
// getConnection returns a connection from the pool
func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, error) {
// Check for existing connection first
sqliteManager.mu.RLock()
conn, exists := sqliteManager.activeConns[connID]
sqliteManager.mu.RUnlock()
if exists {
return conn.Conn, conn.Pool, nil
}
// Get the pool
pool, err := getPool(dbName)
if err != nil {
return nil, nil, err
}
// Get a connection
dbConn, err := pool.Take(context.Background())
if err != nil {
return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
return nil, fmt.Errorf("failed to get connection from pool: %w", err)
}
// Store connection
sqliteManager.mu.Lock()
sqliteManager.activeConns[connID] = &SQLiteConnection{
DbName: dbName,
Conn: dbConn,
Pool: pool,
}
sqliteManager.mu.Unlock()
return dbConn, pool, nil
return dbConn, nil
}
// processParams extracts parameters and connection ID from Lua state
func processParams(state *luajit.State, defaultConnID string) (params any, connID string, isPositional bool, positionalParams []any, err error) {
connID = defaultConnID
// Check if using positional parameters
if state.GetTop() >= 3 && !state.IsTable(3) {
isPositional = true
paramCount := state.GetTop() - 2 // Count all args after db and query
// Check if last param is a connection ID
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
if paramCount > 0 && state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
connID = state.ToString(lastIdx)
paramCount-- // Exclude connID from param count
// releaseConnection returns a connection to its pool
func releaseConnection(connID string) {
if sqliteManager == nil {
return
}
// Create array for positional parameters
positionalParams = make([]any, paramCount)
sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock()
// Collect all parameters
for i := 0; i < paramCount; i++ {
paramIdx := i + 3 // Params start at index 3
switch state.GetType(paramIdx) {
case luajit.TypeNumber:
positionalParams[i] = state.ToNumber(paramIdx)
case luajit.TypeString:
positionalParams[i] = state.ToString(paramIdx)
case luajit.TypeBoolean:
positionalParams[i] = state.ToBoolean(paramIdx)
case luajit.TypeNil:
positionalParams[i] = nil
default:
val, errConv := state.ToValue(paramIdx)
if errConv != nil {
return nil, "", false, nil, fmt.Errorf("failed to convert parameter %d: %w", i+1, errConv)
}
positionalParams[i] = val
}
}
return nil, connID, isPositional, positionalParams, nil
conn, exists := sqliteManager.activeConns[connID]
if !exists {
return
}
// Named parameter handling
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
connID = state.ToString(4)
if conn.Pool != nil {
conn.Pool.Put(conn.Conn)
}
delete(sqliteManager.activeConns, connID)
}
// Get table parameters if present
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
params, err = state.ToTable(3)
}
return params, connID, isPositional, nil, err
}
// prepareExecOptions prepares SQLite execution options based on parameters
func prepareExecOptions(query string, params any, isPositional bool, positionalParams []any) *sqlitex.ExecOptions {
execOpts := &sqlitex.ExecOptions{}
if params == nil && !isPositional {
return execOpts
}
// Prepare parameters
isArray := false
var namedParams map[string]any
var arrParams []any
// Check for array parameters
if m, ok := params.(map[string]any); ok {
if arr, hasArray := m[""]; hasArray {
isArray = true
if slice, ok := arr.([]any); ok {
arrParams = slice
} else if floatSlice, ok := arr.([]float64); ok {
arrParams = make([]any, len(floatSlice))
for i, v := range floatSlice {
arrParams[i] = v
}
}
} else {
// Process named parameters
namedParams = make(map[string]any, len(m))
for k, v := range m {
if len(k) > 0 && k[0] != ':' {
namedParams[":"+k] = v
} else {
namedParams[k] = v
}
}
}
} else if slice, ok := params.([]any); ok {
isArray = true
arrParams = slice
} else if floatSlice, ok := params.([]float64); ok {
isArray = true
arrParams = make([]any, len(floatSlice))
for i, v := range floatSlice {
arrParams[i] = v
}
}
// Use positional params if explicitly provided
if isPositional {
arrParams = positionalParams
isArray = true
}
// Limit positional params to actual placeholders
if isArray && arrParams != nil {
placeholderCount := strings.Count(query, "?")
if len(arrParams) > placeholderCount {
arrParams = arrParams[:placeholderCount]
}
execOpts.Args = arrParams
} else if namedParams != nil {
execOpts.Named = namedParams
}
return execOpts
}
// sqlOperation handles both query and exec operations
func sqlOperation(state *luajit.State, isQuery bool) int {
operation := "query"
if !isQuery {
operation = "exec"
}
// Get database name
if !state.IsString(1) {
state.PushString(fmt.Sprintf("sqlite.%s: database name must be a string", operation))
// sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 3 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.query: requires database name, query, and optional parameters")
return -1
}
dbName := state.ToString(1)
// Get query
if !state.IsString(2) {
state.PushString(fmt.Sprintf("sqlite.%s: query must be a string", operation))
return -1
}
query := state.ToString(2)
// Generate a temporary connection ID if needed
defaultConnID := fmt.Sprintf("temp_%p", &query)
// Process parameters and get connection ID
params, connID, isPositional, positionalParams, err := processParams(state, defaultConnID)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
return -1
}
connID := fmt.Sprintf("temp_%p", &query)
// Get connection
conn, pool, err := getConnection(dbName, connID)
conn, err := getConnection(dbName, connID)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
// Create execution options
var execOpts sqlitex.ExecOptions
rows := make([]map[string]any, 0, 16)
// For temporary connections, defer release
if strings.HasPrefix(connID, "temp_") {
defer func() {
sqliteManager.mu.Lock()
delete(sqliteManager.activeConns, connID)
sqliteManager.mu.Unlock()
pool.Put(conn)
}()
defer releaseConnection(connID)
// Set up parameters if provided
if state.GetTop() >= 3 {
if state.IsTable(3) {
params, err := state.ToTable(3)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error()))
return -1
}
// Prepare execution options
execOpts := prepareExecOptions(query, params, isPositional, positionalParams)
// Check for array-style params
if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
execOpts.Args = arrParams
} else if floatArr, ok := arr.([]float64); ok {
args := make([]any, len(floatArr))
for i, v := range floatArr {
args[i] = v
}
execOpts.Args = args
}
} else {
// Named parameters
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
}
} else {
// Positional parameters
count := state.GetTop() - 2
args := make([]any, count)
for i := 0; i < count; i++ {
idx := i + 3
switch state.GetType(idx) {
case luajit.TypeNumber:
args[i] = state.ToNumber(idx)
case luajit.TypeString:
args[i] = state.ToString(idx)
case luajit.TypeBoolean:
args[i] = state.ToBoolean(idx)
case luajit.TypeNil:
args[i] = nil
default:
val, err := state.ToValue(idx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error()))
return -1
}
args[i] = val
}
}
execOpts.Args = args
}
}
// Define rows slice outside the closure
var rows []map[string]any
// For queries, add result function
if isQuery {
// Set up result function
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
row := make(map[string]any)
columnCount := stmt.ColumnCount()
for i := range columnCount {
columnName := stmt.ColumnName(i)
colCount := stmt.ColumnCount()
for i := range colCount {
colName := stmt.ColumnName(i)
switch stmt.ColumnType(i) {
case sqlite.TypeInteger:
row[columnName] = stmt.ColumnInt64(i)
row[colName] = stmt.ColumnInt64(i)
case sqlite.TypeFloat:
row[columnName] = stmt.ColumnFloat(i)
row[colName] = stmt.ColumnFloat(i)
case sqlite.TypeText:
row[columnName] = stmt.ColumnText(i)
row[colName] = stmt.ColumnText(i)
case sqlite.TypeBlob:
blobSize := stmt.ColumnLen(i)
buf := make([]byte, blobSize)
row[columnName] = stmt.ColumnBytes(i, buf)
row[colName] = stmt.ColumnBytes(i, buf)
case sqlite.TypeNull:
row[columnName] = nil
row[colName] = nil
}
}
// Add row copy to results
rowCopy := make(map[string]any, len(row))
maps.Copy(rowCopy, row)
rows = append(rows, rowCopy)
rows = append(rows, row) // No need to copy, this row is used only once
return nil
}
}
// Execute query
var execErr error
if isQuery || execOpts.Args != nil || execOpts.Named != nil {
execErr = sqlitex.Execute(conn, query, execOpts)
} else {
// Use ExecScript for queries without parameters
execErr = sqlitex.ExecScript(conn, query)
}
if execErr != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, execErr.Error()))
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
// Return results for query, affected rows for exec
if isQuery {
// Create result table with rows
// Create result table
state.NewTable()
for i, row := range rows {
state.PushNumber(float64(i + 1))
if err := state.PushTable(row); err != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1
}
state.SetTable(-3)
}
} else {
// Return number of affected rows
state.PushNumber(float64(conn.Changes()))
}
return 1
}
// luaSQLQuery executes a SQL query and returns results to Lua
func luaSQLQuery(state *luajit.State) int {
return sqlOperation(state, true)
// sqlExec executes a SQL statement without returning results
func sqlExec(state *luajit.State) int {
// Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.exec: requires database name and query")
return -1
}
// luaSQLExec executes a SQL statement without returning results
func luaSQLExec(state *luajit.State) int {
return sqlOperation(state, false)
dbName := state.ToString(1)
query := state.ToString(2)
connID := fmt.Sprintf("temp_%p", &query)
// Get connection
conn, err := getConnection(dbName, connID)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// For temporary connections, defer release
defer releaseConnection(connID)
// Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
hasPlaceholders := strings.Contains(query, "?") || strings.Contains(query, ":")
// Fast path for simple queries with no parameters
if !hasParams || !hasPlaceholders {
if err := sqlitex.ExecScript(conn, query); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
state.PushNumber(float64(conn.Changes()))
return 1
}
// Create execution options for parameterized query
var execOpts sqlitex.ExecOptions
// Set up parameters
if state.IsTable(3) {
params, err := state.ToTable(3)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error()))
return -1
}
// Check for array-style params
if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
execOpts.Args = arrParams
} else if floatArr, ok := arr.([]float64); ok {
args := make([]any, len(floatArr))
for i, v := range floatArr {
args[i] = v
}
execOpts.Args = args
}
} else {
// Named parameters
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
}
} else {
// Positional parameters
count := state.GetTop() - 2
args := make([]any, count)
for i := 0; i < count; i++ {
idx := i + 3
switch state.GetType(idx) {
case luajit.TypeNumber:
args[i] = state.ToNumber(idx)
case luajit.TypeString:
args[i] = state.ToString(idx)
case luajit.TypeBoolean:
args[i] = state.ToBoolean(idx)
case luajit.TypeNil:
args[i] = nil
default:
val, err := state.ToValue(idx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error()))
return -1
}
args[i] = val
}
}
execOpts.Args = args
}
// Count the number of placeholders to validate parameter count
if execOpts.Args != nil {
placeholderCount := strings.Count(query, "?")
if len(execOpts.Args) > placeholderCount {
state.PushString(fmt.Sprintf("sqlite.exec: too many parameters provided (%d) for placeholders (%d)",
len(execOpts.Args), placeholderCount))
return -1
}
}
// Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// Return affected rows
state.PushNumber(float64(conn.Changes()))
return 1
}
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil {
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
return err
}
return state.RegisterGoFunction("__sqlite_exec", luaSQLExec)
return state.RegisterGoFunction("__sqlite_exec", sqlExec)
}