sqlite improvements

This commit is contained in:
Sky Johnson 2025-05-10 14:53:37 -05:00
parent c754877f7d
commit 98b2931d59
2 changed files with 504 additions and 366 deletions

View File

@ -27,20 +27,37 @@ local connection_mt = {
error("connection:query: query must be a string", 2) error("connection:query: query must be a string", 2)
end end
-- Handle params (named or positional) -- Fast path for no parameters
local processed_params = handle_params(params, ...) if params == nil and select('#', ...) == 0 then
return __sqlite_query(self.db_name, query, nil, self.id)
end
-- Call with appropriate arguments -- Handle various parameter types efficiently
if type(processed_params) == "table" and processed_params[1] ~= nil then if type(params) == "table" then
-- Positional parameters - insert self.db_name and query at the beginning -- If it's an array-like table with numeric keys
table.insert(processed_params, 1, query) if params[1] ~= nil then
table.insert(processed_params, 1, self.db_name) -- For positional parameters, we want to include the required prefix args
-- Add connection ID at the end local args = {self.db_name, query}
table.insert(processed_params, self.id) -- Append all parameters
return __sqlite_query(unpack(processed_params)) for i=1, #params do
args[i+2] = params[i]
end
-- Add connection ID
args[#args+1] = self.id
return __sqlite_query(unpack(args))
else else
-- Named parameters or no parameters -- Named parameters
return __sqlite_query(self.db_name, query, processed_params, self.id) return __sqlite_query(self.db_name, query, params, self.id)
end
else
-- Variadic parameters, combine with first param
local args = {self.db_name, query, params}
local n = select('#', ...)
for i=1, n do
args[i+3] = select(i, ...)
end
args[#args+1] = self.id
return __sqlite_query(unpack(args))
end end
end, end,
@ -50,20 +67,37 @@ local connection_mt = {
error("connection:exec: query must be a string", 2) error("connection:exec: query must be a string", 2)
end end
-- Handle params (named or positional) -- Fast path for no parameters
local processed_params = handle_params(params, ...) if params == nil and select('#', ...) == 0 then
return __sqlite_exec(self.db_name, query, nil, self.id)
end
-- Call with appropriate arguments -- Handle various parameter types efficiently
if type(processed_params) == "table" and processed_params[1] ~= nil then if type(params) == "table" then
-- Positional parameters - insert self.db_name and query at the beginning -- If it's an array-like table with numeric keys
table.insert(processed_params, 1, query) if params[1] ~= nil then
table.insert(processed_params, 1, self.db_name) -- For positional parameters, we want to include the required prefix args
-- Add connection ID at the end local args = {self.db_name, query}
table.insert(processed_params, self.id) -- Append all parameters
return __sqlite_exec(unpack(processed_params)) for i=1, #params do
args[i+2] = params[i]
end
-- Add connection ID
args[#args+1] = self.id
return __sqlite_exec(unpack(args))
else else
-- Named parameters or no parameters -- Named parameters
return __sqlite_exec(self.db_name, query, processed_params, self.id) return __sqlite_exec(self.db_name, query, params, self.id)
end
else
-- Variadic parameters, combine with first param
local args = {self.db_name, query, params}
local n = select('#', ...)
for i=1, n do
args[i+3] = select(i, ...)
end
args[#args+1] = self.id
return __sqlite_exec(unpack(args))
end end
end, end,
@ -79,7 +113,7 @@ local connection_mt = {
local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)") local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)")
if index_def then if index_def then
-- Parse index definition: INDEX:idx_name(col1,col2) -- Parse index definition
local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)") local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)")
if index_name and columns_str then if index_name and columns_str then
@ -106,34 +140,32 @@ local connection_mt = {
error("connection:create_table: no columns specified", 2) error("connection:create_table: no columns specified", 2)
end end
-- Create the table -- Build combined statement for table and indices
local query = string.format("CREATE TABLE IF NOT EXISTS %s (%s)", local statements = {}
table_name, table.concat(columns, ", "))
local result = self:exec(query) -- Add the CREATE TABLE statement
table.insert(statements, string.format(
-- Create indices "CREATE TABLE IF NOT EXISTS %s (%s)",
if #indices > 0 then table_name,
self:begin() table.concat(columns, ", ")
))
-- Add CREATE INDEX statements
for _, idx in ipairs(indices) do for _, idx in ipairs(indices) do
local unique = idx.unique and "UNIQUE " or "" local unique = idx.unique and "UNIQUE " or ""
local index_query = string.format( table.insert(statements, string.format(
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)", "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
unique, unique,
idx.name, idx.name,
table_name, table_name,
table.concat(idx.columns, ", ") table.concat(idx.columns, ", ")
) ))
self:exec(index_query)
end end
self:commit() -- Execute all statements in a single transaction
end local combined_sql = table.concat(statements, ";\n")
return self:exec(combined_sql)
return result
end, end,
-- Insert a row or multiple rows -- Insert a row or multiple rows
@ -142,10 +174,44 @@ local connection_mt = {
error("connection:insert: data must be a table", 2) error("connection:insert: data must be a table", 2)
end end
-- Case 1: Named columns with array data
if columns and type(columns) == "table" then if columns and type(columns) == "table" then
-- Check if we have multiple rows
if #data > 0 and type(data[1]) == "table" then
-- Build a single multi-value INSERT
local placeholders = {} local placeholders = {}
for _ in ipairs(columns) do local values = {}
table.insert(placeholders, "?") local params = {}
local param_index = 1
for i, row in ipairs(data) do
local row_placeholders = {}
for j, _ in ipairs(columns) do
local param_name = "p" .. param_index
table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[j]
param_index = param_index + 1
end
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
end
local query = string.format(
"INSERT INTO %s (%s) VALUES %s",
table_name,
table.concat(columns, ", "),
table.concat(placeholders, ", ")
)
return self:exec(query, params)
else
-- Single row with defined columns
local placeholders = {}
local params = {}
for i, col in ipairs(columns) do
local param_name = "p" .. i
table.insert(placeholders, ":" .. param_name)
params[param_name] = data[i]
end end
local query = string.format( local query = string.format(
@ -155,30 +221,11 @@ local connection_mt = {
table.concat(placeholders, ", ") table.concat(placeholders, ", ")
) )
local use_transaction = #data > 1 and type(data[1]) == "table" return self:exec(query, params)
end
if use_transaction then
self:begin()
end
local affected = 0
if #data > 0 and type(data[1]) == "table" then
for _, row in ipairs(data) do
local result = self:exec(query, row)
affected = affected + result
end
else
affected = self:exec(query, data)
end
if use_transaction then
self:commit()
end
return affected
end end
-- Case 2: Object-style single row {col1=val1, col2=val2}
if data[1] == nil and next(data) ~= nil then if data[1] == nil and next(data) ~= nil then
local columns = {} local columns = {}
local placeholders = {} local placeholders = {}
@ -186,8 +233,9 @@ local connection_mt = {
for col, val in pairs(data) do for col, val in pairs(data) do
table.insert(columns, col) table.insert(columns, col)
table.insert(placeholders, ":" .. col) local param_name = "p" .. #columns
params[":" .. col] = val table.insert(placeholders, ":" .. param_name)
params[param_name] = val
end end
local query = string.format( local query = string.format(
@ -200,34 +248,74 @@ local connection_mt = {
return self:exec(query, params) return self:exec(query, params)
end end
-- Case 3: Array of rows without predefined columns
if #data > 0 and type(data[1]) == "table" then if #data > 0 and type(data[1]) == "table" then
self:begin() -- Extract columns from the first row
local affected = 0 local first_row = data[1]
local inferred_columns = {}
for _, row in ipairs(data) do -- Determine if first row is array or object
local result = self:insert(table_name, row) local is_array = first_row[1] ~= nil
affected = affected + result
if is_array then
-- Cannot infer column names from array
error("connection:insert: column names required for array data", 2)
else
-- Get columns from object keys
for col, _ in pairs(first_row) do
table.insert(inferred_columns, col)
end end
self:commit() -- Build multi-value INSERT
return affected local placeholders = {}
local params = {}
local param_index = 1
for _, row in ipairs(data) do
local row_placeholders = {}
for _, col in ipairs(inferred_columns) do
local param_name = "p" .. param_index
table.insert(row_placeholders, ":" .. param_name)
params[param_name] = row[col]
param_index = param_index + 1
end
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
end
local query = string.format(
"INSERT INTO %s (%s) VALUES %s",
table_name,
table.concat(inferred_columns, ", "),
table.concat(placeholders, ", ")
)
return self:exec(query, params)
end
end end
error("connection:insert: invalid data format", 2) error("connection:insert: invalid data format", 2)
end, end,
-- Update rows -- Update rows
update = function(self, table_name, data, where, where_params) update = function(self, table_name, data, where, where_params, ...)
if type(data) ~= "table" then if type(data) ~= "table" then
error("connection:update: data must be a table", 2) error("connection:update: data must be a table", 2)
end end
-- Fast path for when there's no data
if next(data) == nil then
return 0
end
local sets = {} local sets = {}
local params = {} local params = {}
local param_index = 1
for col, val in pairs(data) do for col, val in pairs(data) do
table.insert(sets, col .. " = :" .. col) local param_name = "p" .. param_index
params[col] = val table.insert(sets, col .. " = :" .. param_name)
params[param_name] = val
param_index = param_index + 1
end end
local query = string.format( local query = string.format(
@ -240,8 +328,53 @@ local connection_mt = {
query = query .. " WHERE " .. where query = query .. " WHERE " .. where
if where_params then if where_params then
if type(where_params) == "table" then
-- Handle named parameters in WHERE clause
for k, v in pairs(where_params) do for k, v in pairs(where_params) do
params[k] = v local param_name
if type(k) == "string" and k:sub(1, 1) == ":" then
param_name = k:sub(2)
else
param_name = "w" .. param_index
-- Replace the placeholder in the WHERE clause
where = where:gsub(":" .. k, ":" .. param_name)
end
params[param_name] = v
param_index = param_index + 1
end
else
-- Handle positional parameters (? placeholders)
local args = {where_params, ...}
local pos = 1
local offset = 0
-- Replace ? with named parameters
while true do
local start_pos, end_pos = where:find("?", pos)
if not start_pos then break end
local param_name = "w" .. param_index
local replacement = ":" .. param_name
where = where:sub(1, start_pos - 1) .. replacement .. where:sub(end_pos + 1)
if args[pos - offset] ~= nil then
params[param_name] = args[pos - offset]
else
params[param_name] = nil
end
param_index = param_index + 1
pos = start_pos + #replacement
offset = offset + 1
end
query = string.format(
"UPDATE %s SET %s WHERE %s",
table_name,
table.concat(sets, ", "),
where
)
end end
end end
end end
@ -260,15 +393,25 @@ local connection_mt = {
return self:exec(query, params) return self:exec(query, params)
end, end,
-- Get one row -- Get one row efficiently
get_one = function(self, query, params, ...) get_one = function(self, query, params, ...)
-- Handle both named and positional parameters if type(query) ~= "string" then
error("connection:get_one: query must be a string", 2)
end
-- Add LIMIT 1 to query if not already limited
local limited_query = query
if not query:lower():match("limit%s+%d+") then
limited_query = query .. " LIMIT 1"
end
local results local results
if select('#', ...) > 0 then if select('#', ...) > 0 then
results = self:query(query, params, ...) results = self:query(limited_query, params, ...)
else else
results = self:query(query, params) results = self:query(limited_query, params)
end end
return results[1] return results[1]
end, end,

View File

@ -13,14 +13,11 @@ import (
"Moonshark/utils/logger" "Moonshark/utils/logger"
"maps"
luajit "git.sharkk.net/Sky/LuaJIT-to-Go" luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
) )
// SQLiteConnection tracks an active connection // SQLiteConnection tracks an active connection
type SQLiteConnection struct { type SQLiteConnection struct {
DbName string
Conn *sqlite.Conn Conn *sqlite.Conn
Pool *sqlitex.Pool Pool *sqlitex.Pool
} }
@ -54,7 +51,6 @@ func CleanupSQLite() {
sqliteManager.mu.Lock() sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock() defer sqliteManager.mu.Unlock()
// Release all connections and close pools
for id, conn := range sqliteManager.activeConns { for id, conn := range sqliteManager.activeConns {
if conn.Pool != nil { if conn.Pool != nil {
conn.Pool.Put(conn.Conn) conn.Pool.Put(conn.Conn)
@ -79,9 +75,6 @@ func ReleaseActiveConnections(state *luajit.State) {
return return
} }
sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock()
// Get active connections table from Lua // Get active connections table from Lua
state.GetGlobal("__active_sqlite_connections") state.GetGlobal("__active_sqlite_connections")
if !state.IsTable(-1) { if !state.IsTable(-1) {
@ -89,6 +82,9 @@ func ReleaseActiveConnections(state *luajit.State) {
return return
} }
sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock()
// Iterate through active connections // Iterate through active connections
state.PushNil() // Start iteration state.PushNil() // Start iteration
for state.Next(-2) { for state.Next(-2) {
@ -113,8 +109,8 @@ func ReleaseActiveConnections(state *luajit.State) {
state.SetGlobal("__active_sqlite_connections") state.SetGlobal("__active_sqlite_connections")
} }
// getPool returns a connection pool for the specified database // getConnection returns a connection for the database
func getPool(dbName string) (*sqlitex.Pool, error) { func getConnection(dbName, connID string) (*sqlite.Conn, error) {
if sqliteManager == nil { if sqliteManager == nil {
return nil, errors.New("SQLite not initialized") return nil, errors.New("SQLite not initialized")
} }
@ -125,326 +121,325 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
return nil, errors.New("invalid database name") return nil, errors.New("invalid database name")
} }
// Check for existing pool with read lock // Check for existing connection
sqliteManager.mu.RLock() sqliteManager.mu.RLock()
pool, exists := sqliteManager.pools[dbName] conn, exists := sqliteManager.activeConns[connID]
sqliteManager.mu.RUnlock()
if exists { if exists {
return pool, nil sqliteManager.mu.RUnlock()
return conn.Conn, nil
} }
sqliteManager.mu.RUnlock()
// Create new pool with write lock // Get or create pool under write lock
sqliteManager.mu.Lock() sqliteManager.mu.Lock()
defer sqliteManager.mu.Unlock() defer sqliteManager.mu.Unlock()
// Double check if another goroutine created it // Double-check if a connection was created while waiting for lock
if pool, exists = sqliteManager.pools[dbName]; exists { if conn, exists = sqliteManager.activeConns[connID]; exists {
return pool, nil return conn.Conn, nil
} }
// Create database file path and pool // Get or create pool
pool, exists := sqliteManager.pools[dbName]
if !exists {
dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db") dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{}) var err error
pool, err = sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
sqliteManager.pools[dbName] = pool sqliteManager.pools[dbName] = pool
return pool, nil
}
// getConnection returns a connection from the pool
func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, error) {
// Check for existing connection first
sqliteManager.mu.RLock()
conn, exists := sqliteManager.activeConns[connID]
sqliteManager.mu.RUnlock()
if exists {
return conn.Conn, conn.Pool, nil
}
// Get the pool
pool, err := getPool(dbName)
if err != nil {
return nil, nil, err
} }
// Get a connection // Get a connection
dbConn, err := pool.Take(context.Background()) dbConn, err := pool.Take(context.Background())
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err) return nil, fmt.Errorf("failed to get connection from pool: %w", err)
} }
// Store connection // Store connection
sqliteManager.mu.Lock()
sqliteManager.activeConns[connID] = &SQLiteConnection{ sqliteManager.activeConns[connID] = &SQLiteConnection{
DbName: dbName,
Conn: dbConn, Conn: dbConn,
Pool: pool, Pool: pool,
} }
sqliteManager.mu.Unlock()
return dbConn, pool, nil return dbConn, nil
} }
// processParams extracts parameters and connection ID from Lua state // releaseConnection returns a connection to its pool
func processParams(state *luajit.State, defaultConnID string) (params any, connID string, isPositional bool, positionalParams []any, err error) { func releaseConnection(connID string) {
connID = defaultConnID if sqliteManager == nil {
return
// Check if using positional parameters
if state.GetTop() >= 3 && !state.IsTable(3) {
isPositional = true
paramCount := state.GetTop() - 2 // Count all args after db and query
// Check if last param is a connection ID
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
if paramCount > 0 && state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
connID = state.ToString(lastIdx)
paramCount-- // Exclude connID from param count
} }
// Create array for positional parameters sqliteManager.mu.Lock()
positionalParams = make([]any, paramCount) defer sqliteManager.mu.Unlock()
// Collect all parameters conn, exists := sqliteManager.activeConns[connID]
for i := 0; i < paramCount; i++ { if !exists {
paramIdx := i + 3 // Params start at index 3 return
switch state.GetType(paramIdx) {
case luajit.TypeNumber:
positionalParams[i] = state.ToNumber(paramIdx)
case luajit.TypeString:
positionalParams[i] = state.ToString(paramIdx)
case luajit.TypeBoolean:
positionalParams[i] = state.ToBoolean(paramIdx)
case luajit.TypeNil:
positionalParams[i] = nil
default:
val, errConv := state.ToValue(paramIdx)
if errConv != nil {
return nil, "", false, nil, fmt.Errorf("failed to convert parameter %d: %w", i+1, errConv)
}
positionalParams[i] = val
}
}
return nil, connID, isPositional, positionalParams, nil
} }
// Named parameter handling if conn.Pool != nil {
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) { conn.Pool.Put(conn.Conn)
connID = state.ToString(4) }
delete(sqliteManager.activeConns, connID)
} }
// Get table parameters if present // sqlQuery executes a SQL query and returns results
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) { func sqlQuery(state *luajit.State) int {
params, err = state.ToTable(3) // Get required parameters
} if state.GetTop() < 3 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.query: requires database name, query, and optional parameters")
return params, connID, isPositional, nil, err
}
// prepareExecOptions prepares SQLite execution options based on parameters
func prepareExecOptions(query string, params any, isPositional bool, positionalParams []any) *sqlitex.ExecOptions {
execOpts := &sqlitex.ExecOptions{}
if params == nil && !isPositional {
return execOpts
}
// Prepare parameters
isArray := false
var namedParams map[string]any
var arrParams []any
// Check for array parameters
if m, ok := params.(map[string]any); ok {
if arr, hasArray := m[""]; hasArray {
isArray = true
if slice, ok := arr.([]any); ok {
arrParams = slice
} else if floatSlice, ok := arr.([]float64); ok {
arrParams = make([]any, len(floatSlice))
for i, v := range floatSlice {
arrParams[i] = v
}
}
} else {
// Process named parameters
namedParams = make(map[string]any, len(m))
for k, v := range m {
if len(k) > 0 && k[0] != ':' {
namedParams[":"+k] = v
} else {
namedParams[k] = v
}
}
}
} else if slice, ok := params.([]any); ok {
isArray = true
arrParams = slice
} else if floatSlice, ok := params.([]float64); ok {
isArray = true
arrParams = make([]any, len(floatSlice))
for i, v := range floatSlice {
arrParams[i] = v
}
}
// Use positional params if explicitly provided
if isPositional {
arrParams = positionalParams
isArray = true
}
// Limit positional params to actual placeholders
if isArray && arrParams != nil {
placeholderCount := strings.Count(query, "?")
if len(arrParams) > placeholderCount {
arrParams = arrParams[:placeholderCount]
}
execOpts.Args = arrParams
} else if namedParams != nil {
execOpts.Named = namedParams
}
return execOpts
}
// sqlOperation handles both query and exec operations
func sqlOperation(state *luajit.State, isQuery bool) int {
operation := "query"
if !isQuery {
operation = "exec"
}
// Get database name
if !state.IsString(1) {
state.PushString(fmt.Sprintf("sqlite.%s: database name must be a string", operation))
return -1 return -1
} }
dbName := state.ToString(1) dbName := state.ToString(1)
// Get query
if !state.IsString(2) {
state.PushString(fmt.Sprintf("sqlite.%s: query must be a string", operation))
return -1
}
query := state.ToString(2) query := state.ToString(2)
connID := fmt.Sprintf("temp_%p", &query)
// Generate a temporary connection ID if needed
defaultConnID := fmt.Sprintf("temp_%p", &query)
// Process parameters and get connection ID
params, connID, isPositional, positionalParams, err := processParams(state, defaultConnID)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
return -1
}
// Get connection // Get connection
conn, pool, err := getConnection(dbName, connID) conn, err := getConnection(dbName, connID)
if err != nil { if err != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1 return -1
} }
// Create execution options
var execOpts sqlitex.ExecOptions
rows := make([]map[string]any, 0, 16)
// For temporary connections, defer release // For temporary connections, defer release
if strings.HasPrefix(connID, "temp_") { defer releaseConnection(connID)
defer func() {
sqliteManager.mu.Lock() // Set up parameters if provided
delete(sqliteManager.activeConns, connID) if state.GetTop() >= 3 {
sqliteManager.mu.Unlock() if state.IsTable(3) {
pool.Put(conn) params, err := state.ToTable(3)
}() if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error()))
return -1
} }
// Prepare execution options // Check for array-style params
execOpts := prepareExecOptions(query, params, isPositional, positionalParams) if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
execOpts.Args = arrParams
} else if floatArr, ok := arr.([]float64); ok {
args := make([]any, len(floatArr))
for i, v := range floatArr {
args[i] = v
}
execOpts.Args = args
}
} else {
// Named parameters
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
}
} else {
// Positional parameters
count := state.GetTop() - 2
args := make([]any, count)
for i := 0; i < count; i++ {
idx := i + 3
switch state.GetType(idx) {
case luajit.TypeNumber:
args[i] = state.ToNumber(idx)
case luajit.TypeString:
args[i] = state.ToString(idx)
case luajit.TypeBoolean:
args[i] = state.ToBoolean(idx)
case luajit.TypeNil:
args[i] = nil
default:
val, err := state.ToValue(idx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error()))
return -1
}
args[i] = val
}
}
execOpts.Args = args
}
}
// Define rows slice outside the closure // Set up result function
var rows []map[string]any
// For queries, add result function
if isQuery {
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error { execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
row := make(map[string]any) row := make(map[string]any)
columnCount := stmt.ColumnCount() colCount := stmt.ColumnCount()
for i := range columnCount {
columnName := stmt.ColumnName(i)
for i := range colCount {
colName := stmt.ColumnName(i)
switch stmt.ColumnType(i) { switch stmt.ColumnType(i) {
case sqlite.TypeInteger: case sqlite.TypeInteger:
row[columnName] = stmt.ColumnInt64(i) row[colName] = stmt.ColumnInt64(i)
case sqlite.TypeFloat: case sqlite.TypeFloat:
row[columnName] = stmt.ColumnFloat(i) row[colName] = stmt.ColumnFloat(i)
case sqlite.TypeText: case sqlite.TypeText:
row[columnName] = stmt.ColumnText(i) row[colName] = stmt.ColumnText(i)
case sqlite.TypeBlob: case sqlite.TypeBlob:
blobSize := stmt.ColumnLen(i) blobSize := stmt.ColumnLen(i)
buf := make([]byte, blobSize) buf := make([]byte, blobSize)
row[columnName] = stmt.ColumnBytes(i, buf) row[colName] = stmt.ColumnBytes(i, buf)
case sqlite.TypeNull: case sqlite.TypeNull:
row[columnName] = nil row[colName] = nil
} }
} }
rows = append(rows, row) // No need to copy, this row is used only once
// Add row copy to results
rowCopy := make(map[string]any, len(row))
maps.Copy(rowCopy, row)
rows = append(rows, rowCopy)
return nil return nil
} }
}
// Execute query // Execute query
var execErr error if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
if isQuery || execOpts.Args != nil || execOpts.Named != nil { state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
execErr = sqlitex.Execute(conn, query, execOpts)
} else {
// Use ExecScript for queries without parameters
execErr = sqlitex.ExecScript(conn, query)
}
if execErr != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, execErr.Error()))
return -1 return -1
} }
// Return results for query, affected rows for exec // Create result table
if isQuery {
// Create result table with rows
state.NewTable() state.NewTable()
for i, row := range rows { for i, row := range rows {
state.PushNumber(float64(i + 1)) state.PushNumber(float64(i + 1))
if err := state.PushTable(row); err != nil { if err := state.PushTable(row); err != nil {
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error())) state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1 return -1
} }
state.SetTable(-3) state.SetTable(-3)
} }
} else {
// Return number of affected rows
state.PushNumber(float64(conn.Changes()))
}
return 1 return 1
} }
// luaSQLQuery executes a SQL query and returns results to Lua // sqlExec executes a SQL statement without returning results
func luaSQLQuery(state *luajit.State) int { func sqlExec(state *luajit.State) int {
return sqlOperation(state, true) // Get required parameters
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.exec: requires database name and query")
return -1
} }
// luaSQLExec executes a SQL statement without returning results dbName := state.ToString(1)
func luaSQLExec(state *luajit.State) int { query := state.ToString(2)
return sqlOperation(state, false) connID := fmt.Sprintf("temp_%p", &query)
// Get connection
conn, err := getConnection(dbName, connID)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// For temporary connections, defer release
defer releaseConnection(connID)
// Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
hasPlaceholders := strings.Contains(query, "?") || strings.Contains(query, ":")
// Fast path for simple queries with no parameters
if !hasParams || !hasPlaceholders {
if err := sqlitex.ExecScript(conn, query); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
state.PushNumber(float64(conn.Changes()))
return 1
}
// Create execution options for parameterized query
var execOpts sqlitex.ExecOptions
// Set up parameters
if state.IsTable(3) {
params, err := state.ToTable(3)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error()))
return -1
}
// Check for array-style params
if arr, ok := params[""]; ok {
if arrParams, ok := arr.([]any); ok {
execOpts.Args = arrParams
} else if floatArr, ok := arr.([]float64); ok {
args := make([]any, len(floatArr))
for i, v := range floatArr {
args[i] = v
}
execOpts.Args = args
}
} else {
// Named parameters
named := make(map[string]any, len(params))
for k, v := range params {
if len(k) > 0 && k[0] != ':' {
named[":"+k] = v
} else {
named[k] = v
}
}
execOpts.Named = named
}
} else {
// Positional parameters
count := state.GetTop() - 2
args := make([]any, count)
for i := 0; i < count; i++ {
idx := i + 3
switch state.GetType(idx) {
case luajit.TypeNumber:
args[i] = state.ToNumber(idx)
case luajit.TypeString:
args[i] = state.ToString(idx)
case luajit.TypeBoolean:
args[i] = state.ToBoolean(idx)
case luajit.TypeNil:
args[i] = nil
default:
val, err := state.ToValue(idx)
if err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error()))
return -1
}
args[i] = val
}
}
execOpts.Args = args
}
// Count the number of placeholders to validate parameter count
if execOpts.Args != nil {
placeholderCount := strings.Count(query, "?")
if len(execOpts.Args) > placeholderCount {
state.PushString(fmt.Sprintf("sqlite.exec: too many parameters provided (%d) for placeholders (%d)",
len(execOpts.Args), placeholderCount))
return -1
}
}
// Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
// Return affected rows
state.PushNumber(float64(conn.Changes()))
return 1
} }
// RegisterSQLiteFunctions registers SQLite functions with the Lua state // RegisterSQLiteFunctions registers SQLite functions with the Lua state
func RegisterSQLiteFunctions(state *luajit.State) error { func RegisterSQLiteFunctions(state *luajit.State) error {
if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil { if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
return err return err
} }
return state.RegisterGoFunction("__sqlite_exec", luaSQLExec) return state.RegisterGoFunction("__sqlite_exec", sqlExec)
} }