sqlite improvements
This commit is contained in:
parent
c754877f7d
commit
98b2931d59
@ -27,20 +27,37 @@ local connection_mt = {
|
||||
error("connection:query: query must be a string", 2)
|
||||
end
|
||||
|
||||
-- Handle params (named or positional)
|
||||
local processed_params = handle_params(params, ...)
|
||||
-- Fast path for no parameters
|
||||
if params == nil and select('#', ...) == 0 then
|
||||
return __sqlite_query(self.db_name, query, nil, self.id)
|
||||
end
|
||||
|
||||
-- Call with appropriate arguments
|
||||
if type(processed_params) == "table" and processed_params[1] ~= nil then
|
||||
-- Positional parameters - insert self.db_name and query at the beginning
|
||||
table.insert(processed_params, 1, query)
|
||||
table.insert(processed_params, 1, self.db_name)
|
||||
-- Add connection ID at the end
|
||||
table.insert(processed_params, self.id)
|
||||
return __sqlite_query(unpack(processed_params))
|
||||
-- Handle various parameter types efficiently
|
||||
if type(params) == "table" then
|
||||
-- If it's an array-like table with numeric keys
|
||||
if params[1] ~= nil then
|
||||
-- For positional parameters, we want to include the required prefix args
|
||||
local args = {self.db_name, query}
|
||||
-- Append all parameters
|
||||
for i=1, #params do
|
||||
args[i+2] = params[i]
|
||||
end
|
||||
-- Add connection ID
|
||||
args[#args+1] = self.id
|
||||
return __sqlite_query(unpack(args))
|
||||
else
|
||||
-- Named parameters or no parameters
|
||||
return __sqlite_query(self.db_name, query, processed_params, self.id)
|
||||
-- Named parameters
|
||||
return __sqlite_query(self.db_name, query, params, self.id)
|
||||
end
|
||||
else
|
||||
-- Variadic parameters, combine with first param
|
||||
local args = {self.db_name, query, params}
|
||||
local n = select('#', ...)
|
||||
for i=1, n do
|
||||
args[i+3] = select(i, ...)
|
||||
end
|
||||
args[#args+1] = self.id
|
||||
return __sqlite_query(unpack(args))
|
||||
end
|
||||
end,
|
||||
|
||||
@ -50,20 +67,37 @@ local connection_mt = {
|
||||
error("connection:exec: query must be a string", 2)
|
||||
end
|
||||
|
||||
-- Handle params (named or positional)
|
||||
local processed_params = handle_params(params, ...)
|
||||
-- Fast path for no parameters
|
||||
if params == nil and select('#', ...) == 0 then
|
||||
return __sqlite_exec(self.db_name, query, nil, self.id)
|
||||
end
|
||||
|
||||
-- Call with appropriate arguments
|
||||
if type(processed_params) == "table" and processed_params[1] ~= nil then
|
||||
-- Positional parameters - insert self.db_name and query at the beginning
|
||||
table.insert(processed_params, 1, query)
|
||||
table.insert(processed_params, 1, self.db_name)
|
||||
-- Add connection ID at the end
|
||||
table.insert(processed_params, self.id)
|
||||
return __sqlite_exec(unpack(processed_params))
|
||||
-- Handle various parameter types efficiently
|
||||
if type(params) == "table" then
|
||||
-- If it's an array-like table with numeric keys
|
||||
if params[1] ~= nil then
|
||||
-- For positional parameters, we want to include the required prefix args
|
||||
local args = {self.db_name, query}
|
||||
-- Append all parameters
|
||||
for i=1, #params do
|
||||
args[i+2] = params[i]
|
||||
end
|
||||
-- Add connection ID
|
||||
args[#args+1] = self.id
|
||||
return __sqlite_exec(unpack(args))
|
||||
else
|
||||
-- Named parameters or no parameters
|
||||
return __sqlite_exec(self.db_name, query, processed_params, self.id)
|
||||
-- Named parameters
|
||||
return __sqlite_exec(self.db_name, query, params, self.id)
|
||||
end
|
||||
else
|
||||
-- Variadic parameters, combine with first param
|
||||
local args = {self.db_name, query, params}
|
||||
local n = select('#', ...)
|
||||
for i=1, n do
|
||||
args[i+3] = select(i, ...)
|
||||
end
|
||||
args[#args+1] = self.id
|
||||
return __sqlite_exec(unpack(args))
|
||||
end
|
||||
end,
|
||||
|
||||
@ -79,7 +113,7 @@ local connection_mt = {
|
||||
local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)")
|
||||
|
||||
if index_def then
|
||||
-- Parse index definition: INDEX:idx_name(col1,col2)
|
||||
-- Parse index definition
|
||||
local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)")
|
||||
|
||||
if index_name and columns_str then
|
||||
@ -106,34 +140,32 @@ local connection_mt = {
|
||||
error("connection:create_table: no columns specified", 2)
|
||||
end
|
||||
|
||||
-- Create the table
|
||||
local query = string.format("CREATE TABLE IF NOT EXISTS %s (%s)",
|
||||
table_name, table.concat(columns, ", "))
|
||||
-- Build combined statement for table and indices
|
||||
local statements = {}
|
||||
|
||||
local result = self:exec(query)
|
||||
|
||||
-- Create indices
|
||||
if #indices > 0 then
|
||||
self:begin()
|
||||
-- Add the CREATE TABLE statement
|
||||
table.insert(statements, string.format(
|
||||
"CREATE TABLE IF NOT EXISTS %s (%s)",
|
||||
table_name,
|
||||
table.concat(columns, ", ")
|
||||
))
|
||||
|
||||
-- Add CREATE INDEX statements
|
||||
for _, idx in ipairs(indices) do
|
||||
local unique = idx.unique and "UNIQUE " or ""
|
||||
|
||||
local index_query = string.format(
|
||||
table.insert(statements, string.format(
|
||||
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
|
||||
unique,
|
||||
idx.name,
|
||||
table_name,
|
||||
table.concat(idx.columns, ", ")
|
||||
)
|
||||
|
||||
self:exec(index_query)
|
||||
))
|
||||
end
|
||||
|
||||
self:commit()
|
||||
end
|
||||
|
||||
return result
|
||||
-- Execute all statements in a single transaction
|
||||
local combined_sql = table.concat(statements, ";\n")
|
||||
return self:exec(combined_sql)
|
||||
end,
|
||||
|
||||
-- Insert a row or multiple rows
|
||||
@ -142,10 +174,44 @@ local connection_mt = {
|
||||
error("connection:insert: data must be a table", 2)
|
||||
end
|
||||
|
||||
-- Case 1: Named columns with array data
|
||||
if columns and type(columns) == "table" then
|
||||
-- Check if we have multiple rows
|
||||
if #data > 0 and type(data[1]) == "table" then
|
||||
-- Build a single multi-value INSERT
|
||||
local placeholders = {}
|
||||
for _ in ipairs(columns) do
|
||||
table.insert(placeholders, "?")
|
||||
local values = {}
|
||||
local params = {}
|
||||
local param_index = 1
|
||||
|
||||
for i, row in ipairs(data) do
|
||||
local row_placeholders = {}
|
||||
for j, _ in ipairs(columns) do
|
||||
local param_name = "p" .. param_index
|
||||
table.insert(row_placeholders, ":" .. param_name)
|
||||
params[param_name] = row[j]
|
||||
param_index = param_index + 1
|
||||
end
|
||||
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES %s",
|
||||
table_name,
|
||||
table.concat(columns, ", "),
|
||||
table.concat(placeholders, ", ")
|
||||
)
|
||||
|
||||
return self:exec(query, params)
|
||||
else
|
||||
-- Single row with defined columns
|
||||
local placeholders = {}
|
||||
local params = {}
|
||||
|
||||
for i, col in ipairs(columns) do
|
||||
local param_name = "p" .. i
|
||||
table.insert(placeholders, ":" .. param_name)
|
||||
params[param_name] = data[i]
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
@ -155,30 +221,11 @@ local connection_mt = {
|
||||
table.concat(placeholders, ", ")
|
||||
)
|
||||
|
||||
local use_transaction = #data > 1 and type(data[1]) == "table"
|
||||
|
||||
if use_transaction then
|
||||
self:begin()
|
||||
end
|
||||
|
||||
local affected = 0
|
||||
|
||||
if #data > 0 and type(data[1]) == "table" then
|
||||
for _, row in ipairs(data) do
|
||||
local result = self:exec(query, row)
|
||||
affected = affected + result
|
||||
end
|
||||
else
|
||||
affected = self:exec(query, data)
|
||||
end
|
||||
|
||||
if use_transaction then
|
||||
self:commit()
|
||||
end
|
||||
|
||||
return affected
|
||||
return self:exec(query, params)
|
||||
end
|
||||
end
|
||||
|
||||
-- Case 2: Object-style single row {col1=val1, col2=val2}
|
||||
if data[1] == nil and next(data) ~= nil then
|
||||
local columns = {}
|
||||
local placeholders = {}
|
||||
@ -186,8 +233,9 @@ local connection_mt = {
|
||||
|
||||
for col, val in pairs(data) do
|
||||
table.insert(columns, col)
|
||||
table.insert(placeholders, ":" .. col)
|
||||
params[":" .. col] = val
|
||||
local param_name = "p" .. #columns
|
||||
table.insert(placeholders, ":" .. param_name)
|
||||
params[param_name] = val
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
@ -200,34 +248,74 @@ local connection_mt = {
|
||||
return self:exec(query, params)
|
||||
end
|
||||
|
||||
-- Case 3: Array of rows without predefined columns
|
||||
if #data > 0 and type(data[1]) == "table" then
|
||||
self:begin()
|
||||
local affected = 0
|
||||
-- Extract columns from the first row
|
||||
local first_row = data[1]
|
||||
local inferred_columns = {}
|
||||
|
||||
for _, row in ipairs(data) do
|
||||
local result = self:insert(table_name, row)
|
||||
affected = affected + result
|
||||
-- Determine if first row is array or object
|
||||
local is_array = first_row[1] ~= nil
|
||||
|
||||
if is_array then
|
||||
-- Cannot infer column names from array
|
||||
error("connection:insert: column names required for array data", 2)
|
||||
else
|
||||
-- Get columns from object keys
|
||||
for col, _ in pairs(first_row) do
|
||||
table.insert(inferred_columns, col)
|
||||
end
|
||||
|
||||
self:commit()
|
||||
return affected
|
||||
-- Build multi-value INSERT
|
||||
local placeholders = {}
|
||||
local params = {}
|
||||
local param_index = 1
|
||||
|
||||
for _, row in ipairs(data) do
|
||||
local row_placeholders = {}
|
||||
for _, col in ipairs(inferred_columns) do
|
||||
local param_name = "p" .. param_index
|
||||
table.insert(row_placeholders, ":" .. param_name)
|
||||
params[param_name] = row[col]
|
||||
param_index = param_index + 1
|
||||
end
|
||||
table.insert(placeholders, "(" .. table.concat(row_placeholders, ", ") .. ")")
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
"INSERT INTO %s (%s) VALUES %s",
|
||||
table_name,
|
||||
table.concat(inferred_columns, ", "),
|
||||
table.concat(placeholders, ", ")
|
||||
)
|
||||
|
||||
return self:exec(query, params)
|
||||
end
|
||||
end
|
||||
|
||||
error("connection:insert: invalid data format", 2)
|
||||
end,
|
||||
|
||||
-- Update rows
|
||||
update = function(self, table_name, data, where, where_params)
|
||||
update = function(self, table_name, data, where, where_params, ...)
|
||||
if type(data) ~= "table" then
|
||||
error("connection:update: data must be a table", 2)
|
||||
end
|
||||
|
||||
-- Fast path for when there's no data
|
||||
if next(data) == nil then
|
||||
return 0
|
||||
end
|
||||
|
||||
local sets = {}
|
||||
local params = {}
|
||||
local param_index = 1
|
||||
|
||||
for col, val in pairs(data) do
|
||||
table.insert(sets, col .. " = :" .. col)
|
||||
params[col] = val
|
||||
local param_name = "p" .. param_index
|
||||
table.insert(sets, col .. " = :" .. param_name)
|
||||
params[param_name] = val
|
||||
param_index = param_index + 1
|
||||
end
|
||||
|
||||
local query = string.format(
|
||||
@ -240,8 +328,53 @@ local connection_mt = {
|
||||
query = query .. " WHERE " .. where
|
||||
|
||||
if where_params then
|
||||
if type(where_params) == "table" then
|
||||
-- Handle named parameters in WHERE clause
|
||||
for k, v in pairs(where_params) do
|
||||
params[k] = v
|
||||
local param_name
|
||||
if type(k) == "string" and k:sub(1, 1) == ":" then
|
||||
param_name = k:sub(2)
|
||||
else
|
||||
param_name = "w" .. param_index
|
||||
-- Replace the placeholder in the WHERE clause
|
||||
where = where:gsub(":" .. k, ":" .. param_name)
|
||||
end
|
||||
params[param_name] = v
|
||||
param_index = param_index + 1
|
||||
end
|
||||
else
|
||||
-- Handle positional parameters (? placeholders)
|
||||
local args = {where_params, ...}
|
||||
local pos = 1
|
||||
local offset = 0
|
||||
|
||||
-- Replace ? with named parameters
|
||||
while true do
|
||||
local start_pos, end_pos = where:find("?", pos)
|
||||
if not start_pos then break end
|
||||
|
||||
local param_name = "w" .. param_index
|
||||
local replacement = ":" .. param_name
|
||||
|
||||
where = where:sub(1, start_pos - 1) .. replacement .. where:sub(end_pos + 1)
|
||||
|
||||
if args[pos - offset] ~= nil then
|
||||
params[param_name] = args[pos - offset]
|
||||
else
|
||||
params[param_name] = nil
|
||||
end
|
||||
|
||||
param_index = param_index + 1
|
||||
pos = start_pos + #replacement
|
||||
offset = offset + 1
|
||||
end
|
||||
|
||||
query = string.format(
|
||||
"UPDATE %s SET %s WHERE %s",
|
||||
table_name,
|
||||
table.concat(sets, ", "),
|
||||
where
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
@ -260,15 +393,25 @@ local connection_mt = {
|
||||
return self:exec(query, params)
|
||||
end,
|
||||
|
||||
-- Get one row
|
||||
-- Get one row efficiently
|
||||
get_one = function(self, query, params, ...)
|
||||
-- Handle both named and positional parameters
|
||||
if type(query) ~= "string" then
|
||||
error("connection:get_one: query must be a string", 2)
|
||||
end
|
||||
|
||||
-- Add LIMIT 1 to query if not already limited
|
||||
local limited_query = query
|
||||
if not query:lower():match("limit%s+%d+") then
|
||||
limited_query = query .. " LIMIT 1"
|
||||
end
|
||||
|
||||
local results
|
||||
if select('#', ...) > 0 then
|
||||
results = self:query(query, params, ...)
|
||||
results = self:query(limited_query, params, ...)
|
||||
else
|
||||
results = self:query(query, params)
|
||||
results = self:query(limited_query, params)
|
||||
end
|
||||
|
||||
return results[1]
|
||||
end,
|
||||
|
||||
|
477
runner/sqlite.go
477
runner/sqlite.go
@ -13,14 +13,11 @@ import (
|
||||
|
||||
"Moonshark/utils/logger"
|
||||
|
||||
"maps"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// SQLiteConnection tracks an active connection
|
||||
type SQLiteConnection struct {
|
||||
DbName string
|
||||
Conn *sqlite.Conn
|
||||
Pool *sqlitex.Pool
|
||||
}
|
||||
@ -54,7 +51,6 @@ func CleanupSQLite() {
|
||||
sqliteManager.mu.Lock()
|
||||
defer sqliteManager.mu.Unlock()
|
||||
|
||||
// Release all connections and close pools
|
||||
for id, conn := range sqliteManager.activeConns {
|
||||
if conn.Pool != nil {
|
||||
conn.Pool.Put(conn.Conn)
|
||||
@ -79,9 +75,6 @@ func ReleaseActiveConnections(state *luajit.State) {
|
||||
return
|
||||
}
|
||||
|
||||
sqliteManager.mu.Lock()
|
||||
defer sqliteManager.mu.Unlock()
|
||||
|
||||
// Get active connections table from Lua
|
||||
state.GetGlobal("__active_sqlite_connections")
|
||||
if !state.IsTable(-1) {
|
||||
@ -89,6 +82,9 @@ func ReleaseActiveConnections(state *luajit.State) {
|
||||
return
|
||||
}
|
||||
|
||||
sqliteManager.mu.Lock()
|
||||
defer sqliteManager.mu.Unlock()
|
||||
|
||||
// Iterate through active connections
|
||||
state.PushNil() // Start iteration
|
||||
for state.Next(-2) {
|
||||
@ -113,8 +109,8 @@ func ReleaseActiveConnections(state *luajit.State) {
|
||||
state.SetGlobal("__active_sqlite_connections")
|
||||
}
|
||||
|
||||
// getPool returns a connection pool for the specified database
|
||||
func getPool(dbName string) (*sqlitex.Pool, error) {
|
||||
// getConnection returns a connection for the database
|
||||
func getConnection(dbName, connID string) (*sqlite.Conn, error) {
|
||||
if sqliteManager == nil {
|
||||
return nil, errors.New("SQLite not initialized")
|
||||
}
|
||||
@ -125,326 +121,325 @@ func getPool(dbName string) (*sqlitex.Pool, error) {
|
||||
return nil, errors.New("invalid database name")
|
||||
}
|
||||
|
||||
// Check for existing pool with read lock
|
||||
// Check for existing connection
|
||||
sqliteManager.mu.RLock()
|
||||
pool, exists := sqliteManager.pools[dbName]
|
||||
sqliteManager.mu.RUnlock()
|
||||
conn, exists := sqliteManager.activeConns[connID]
|
||||
if exists {
|
||||
return pool, nil
|
||||
sqliteManager.mu.RUnlock()
|
||||
return conn.Conn, nil
|
||||
}
|
||||
sqliteManager.mu.RUnlock()
|
||||
|
||||
// Create new pool with write lock
|
||||
// Get or create pool under write lock
|
||||
sqliteManager.mu.Lock()
|
||||
defer sqliteManager.mu.Unlock()
|
||||
|
||||
// Double check if another goroutine created it
|
||||
if pool, exists = sqliteManager.pools[dbName]; exists {
|
||||
return pool, nil
|
||||
// Double-check if a connection was created while waiting for lock
|
||||
if conn, exists = sqliteManager.activeConns[connID]; exists {
|
||||
return conn.Conn, nil
|
||||
}
|
||||
|
||||
// Create database file path and pool
|
||||
// Get or create pool
|
||||
pool, exists := sqliteManager.pools[dbName]
|
||||
if !exists {
|
||||
dbPath := filepath.Join(sqliteManager.dataDir, dbName+".db")
|
||||
pool, err := sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
|
||||
var err error
|
||||
pool, err = sqlitex.NewPool(dbPath, sqlitex.PoolOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
sqliteManager.pools[dbName] = pool
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// getConnection returns a connection from the pool
|
||||
func getConnection(dbName string, connID string) (*sqlite.Conn, *sqlitex.Pool, error) {
|
||||
// Check for existing connection first
|
||||
sqliteManager.mu.RLock()
|
||||
conn, exists := sqliteManager.activeConns[connID]
|
||||
sqliteManager.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return conn.Conn, conn.Pool, nil
|
||||
}
|
||||
|
||||
// Get the pool
|
||||
pool, err := getPool(dbName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Get a connection
|
||||
dbConn, err := pool.Take(context.Background())
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get connection from pool: %w", err)
|
||||
return nil, fmt.Errorf("failed to get connection from pool: %w", err)
|
||||
}
|
||||
|
||||
// Store connection
|
||||
sqliteManager.mu.Lock()
|
||||
sqliteManager.activeConns[connID] = &SQLiteConnection{
|
||||
DbName: dbName,
|
||||
Conn: dbConn,
|
||||
Pool: pool,
|
||||
}
|
||||
sqliteManager.mu.Unlock()
|
||||
|
||||
return dbConn, pool, nil
|
||||
return dbConn, nil
|
||||
}
|
||||
|
||||
// processParams extracts parameters and connection ID from Lua state
|
||||
func processParams(state *luajit.State, defaultConnID string) (params any, connID string, isPositional bool, positionalParams []any, err error) {
|
||||
connID = defaultConnID
|
||||
|
||||
// Check if using positional parameters
|
||||
if state.GetTop() >= 3 && !state.IsTable(3) {
|
||||
isPositional = true
|
||||
paramCount := state.GetTop() - 2 // Count all args after db and query
|
||||
|
||||
// Check if last param is a connection ID
|
||||
lastIdx := paramCount + 2 // db(1) + query(2) + paramCount
|
||||
if paramCount > 0 && state.IsString(lastIdx) && state.GetType(lastIdx-1) != state.GetType(lastIdx) {
|
||||
connID = state.ToString(lastIdx)
|
||||
paramCount-- // Exclude connID from param count
|
||||
// releaseConnection returns a connection to its pool
|
||||
func releaseConnection(connID string) {
|
||||
if sqliteManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Create array for positional parameters
|
||||
positionalParams = make([]any, paramCount)
|
||||
sqliteManager.mu.Lock()
|
||||
defer sqliteManager.mu.Unlock()
|
||||
|
||||
// Collect all parameters
|
||||
for i := 0; i < paramCount; i++ {
|
||||
paramIdx := i + 3 // Params start at index 3
|
||||
switch state.GetType(paramIdx) {
|
||||
case luajit.TypeNumber:
|
||||
positionalParams[i] = state.ToNumber(paramIdx)
|
||||
case luajit.TypeString:
|
||||
positionalParams[i] = state.ToString(paramIdx)
|
||||
case luajit.TypeBoolean:
|
||||
positionalParams[i] = state.ToBoolean(paramIdx)
|
||||
case luajit.TypeNil:
|
||||
positionalParams[i] = nil
|
||||
default:
|
||||
val, errConv := state.ToValue(paramIdx)
|
||||
if errConv != nil {
|
||||
return nil, "", false, nil, fmt.Errorf("failed to convert parameter %d: %w", i+1, errConv)
|
||||
}
|
||||
positionalParams[i] = val
|
||||
}
|
||||
}
|
||||
return nil, connID, isPositional, positionalParams, nil
|
||||
conn, exists := sqliteManager.activeConns[connID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Named parameter handling
|
||||
if state.GetTop() >= 4 && !state.IsNil(4) && state.IsString(4) {
|
||||
connID = state.ToString(4)
|
||||
if conn.Pool != nil {
|
||||
conn.Pool.Put(conn.Conn)
|
||||
}
|
||||
delete(sqliteManager.activeConns, connID)
|
||||
}
|
||||
|
||||
// Get table parameters if present
|
||||
if state.GetTop() >= 3 && !state.IsNil(3) && state.IsTable(3) {
|
||||
params, err = state.ToTable(3)
|
||||
}
|
||||
|
||||
return params, connID, isPositional, nil, err
|
||||
}
|
||||
|
||||
// prepareExecOptions prepares SQLite execution options based on parameters
|
||||
func prepareExecOptions(query string, params any, isPositional bool, positionalParams []any) *sqlitex.ExecOptions {
|
||||
execOpts := &sqlitex.ExecOptions{}
|
||||
|
||||
if params == nil && !isPositional {
|
||||
return execOpts
|
||||
}
|
||||
|
||||
// Prepare parameters
|
||||
isArray := false
|
||||
var namedParams map[string]any
|
||||
var arrParams []any
|
||||
|
||||
// Check for array parameters
|
||||
if m, ok := params.(map[string]any); ok {
|
||||
if arr, hasArray := m[""]; hasArray {
|
||||
isArray = true
|
||||
if slice, ok := arr.([]any); ok {
|
||||
arrParams = slice
|
||||
} else if floatSlice, ok := arr.([]float64); ok {
|
||||
arrParams = make([]any, len(floatSlice))
|
||||
for i, v := range floatSlice {
|
||||
arrParams[i] = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Process named parameters
|
||||
namedParams = make(map[string]any, len(m))
|
||||
for k, v := range m {
|
||||
if len(k) > 0 && k[0] != ':' {
|
||||
namedParams[":"+k] = v
|
||||
} else {
|
||||
namedParams[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if slice, ok := params.([]any); ok {
|
||||
isArray = true
|
||||
arrParams = slice
|
||||
} else if floatSlice, ok := params.([]float64); ok {
|
||||
isArray = true
|
||||
arrParams = make([]any, len(floatSlice))
|
||||
for i, v := range floatSlice {
|
||||
arrParams[i] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Use positional params if explicitly provided
|
||||
if isPositional {
|
||||
arrParams = positionalParams
|
||||
isArray = true
|
||||
}
|
||||
|
||||
// Limit positional params to actual placeholders
|
||||
if isArray && arrParams != nil {
|
||||
placeholderCount := strings.Count(query, "?")
|
||||
if len(arrParams) > placeholderCount {
|
||||
arrParams = arrParams[:placeholderCount]
|
||||
}
|
||||
execOpts.Args = arrParams
|
||||
} else if namedParams != nil {
|
||||
execOpts.Named = namedParams
|
||||
}
|
||||
|
||||
return execOpts
|
||||
}
|
||||
|
||||
// sqlOperation handles both query and exec operations
|
||||
func sqlOperation(state *luajit.State, isQuery bool) int {
|
||||
operation := "query"
|
||||
if !isQuery {
|
||||
operation = "exec"
|
||||
}
|
||||
|
||||
// Get database name
|
||||
if !state.IsString(1) {
|
||||
state.PushString(fmt.Sprintf("sqlite.%s: database name must be a string", operation))
|
||||
// sqlQuery executes a SQL query and returns results
|
||||
func sqlQuery(state *luajit.State) int {
|
||||
// Get required parameters
|
||||
if state.GetTop() < 3 || !state.IsString(1) || !state.IsString(2) {
|
||||
state.PushString("sqlite.query: requires database name, query, and optional parameters")
|
||||
return -1
|
||||
}
|
||||
|
||||
dbName := state.ToString(1)
|
||||
|
||||
// Get query
|
||||
if !state.IsString(2) {
|
||||
state.PushString(fmt.Sprintf("sqlite.%s: query must be a string", operation))
|
||||
return -1
|
||||
}
|
||||
query := state.ToString(2)
|
||||
|
||||
// Generate a temporary connection ID if needed
|
||||
defaultConnID := fmt.Sprintf("temp_%p", &query)
|
||||
|
||||
// Process parameters and get connection ID
|
||||
params, connID, isPositional, positionalParams, err := processParams(state, defaultConnID)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
|
||||
return -1
|
||||
}
|
||||
connID := fmt.Sprintf("temp_%p", &query)
|
||||
|
||||
// Get connection
|
||||
conn, pool, err := getConnection(dbName, connID)
|
||||
conn, err := getConnection(dbName, connID)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Create execution options
|
||||
var execOpts sqlitex.ExecOptions
|
||||
rows := make([]map[string]any, 0, 16)
|
||||
|
||||
// For temporary connections, defer release
|
||||
if strings.HasPrefix(connID, "temp_") {
|
||||
defer func() {
|
||||
sqliteManager.mu.Lock()
|
||||
delete(sqliteManager.activeConns, connID)
|
||||
sqliteManager.mu.Unlock()
|
||||
pool.Put(conn)
|
||||
}()
|
||||
defer releaseConnection(connID)
|
||||
|
||||
// Set up parameters if provided
|
||||
if state.GetTop() >= 3 {
|
||||
if state.IsTable(3) {
|
||||
params, err := state.ToTable(3)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: invalid parameters: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Prepare execution options
|
||||
execOpts := prepareExecOptions(query, params, isPositional, positionalParams)
|
||||
// Check for array-style params
|
||||
if arr, ok := params[""]; ok {
|
||||
if arrParams, ok := arr.([]any); ok {
|
||||
execOpts.Args = arrParams
|
||||
} else if floatArr, ok := arr.([]float64); ok {
|
||||
args := make([]any, len(floatArr))
|
||||
for i, v := range floatArr {
|
||||
args[i] = v
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
} else {
|
||||
// Named parameters
|
||||
named := make(map[string]any, len(params))
|
||||
for k, v := range params {
|
||||
if len(k) > 0 && k[0] != ':' {
|
||||
named[":"+k] = v
|
||||
} else {
|
||||
named[k] = v
|
||||
}
|
||||
}
|
||||
execOpts.Named = named
|
||||
}
|
||||
} else {
|
||||
// Positional parameters
|
||||
count := state.GetTop() - 2
|
||||
args := make([]any, count)
|
||||
for i := 0; i < count; i++ {
|
||||
idx := i + 3
|
||||
switch state.GetType(idx) {
|
||||
case luajit.TypeNumber:
|
||||
args[i] = state.ToNumber(idx)
|
||||
case luajit.TypeString:
|
||||
args[i] = state.ToString(idx)
|
||||
case luajit.TypeBoolean:
|
||||
args[i] = state.ToBoolean(idx)
|
||||
case luajit.TypeNil:
|
||||
args[i] = nil
|
||||
default:
|
||||
val, err := state.ToValue(idx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: invalid parameter %d: %s", i+1, err.Error()))
|
||||
return -1
|
||||
}
|
||||
args[i] = val
|
||||
}
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
}
|
||||
|
||||
// Define rows slice outside the closure
|
||||
var rows []map[string]any
|
||||
|
||||
// For queries, add result function
|
||||
if isQuery {
|
||||
// Set up result function
|
||||
execOpts.ResultFunc = func(stmt *sqlite.Stmt) error {
|
||||
row := make(map[string]any)
|
||||
columnCount := stmt.ColumnCount()
|
||||
|
||||
for i := range columnCount {
|
||||
columnName := stmt.ColumnName(i)
|
||||
colCount := stmt.ColumnCount()
|
||||
|
||||
for i := range colCount {
|
||||
colName := stmt.ColumnName(i)
|
||||
switch stmt.ColumnType(i) {
|
||||
case sqlite.TypeInteger:
|
||||
row[columnName] = stmt.ColumnInt64(i)
|
||||
row[colName] = stmt.ColumnInt64(i)
|
||||
case sqlite.TypeFloat:
|
||||
row[columnName] = stmt.ColumnFloat(i)
|
||||
row[colName] = stmt.ColumnFloat(i)
|
||||
case sqlite.TypeText:
|
||||
row[columnName] = stmt.ColumnText(i)
|
||||
row[colName] = stmt.ColumnText(i)
|
||||
case sqlite.TypeBlob:
|
||||
blobSize := stmt.ColumnLen(i)
|
||||
buf := make([]byte, blobSize)
|
||||
row[columnName] = stmt.ColumnBytes(i, buf)
|
||||
row[colName] = stmt.ColumnBytes(i, buf)
|
||||
case sqlite.TypeNull:
|
||||
row[columnName] = nil
|
||||
row[colName] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add row copy to results
|
||||
rowCopy := make(map[string]any, len(row))
|
||||
maps.Copy(rowCopy, row)
|
||||
rows = append(rows, rowCopy)
|
||||
rows = append(rows, row) // No need to copy, this row is used only once
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
var execErr error
|
||||
if isQuery || execOpts.Args != nil || execOpts.Named != nil {
|
||||
execErr = sqlitex.Execute(conn, query, execOpts)
|
||||
} else {
|
||||
// Use ExecScript for queries without parameters
|
||||
execErr = sqlitex.ExecScript(conn, query)
|
||||
}
|
||||
|
||||
if execErr != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, execErr.Error()))
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Return results for query, affected rows for exec
|
||||
if isQuery {
|
||||
// Create result table with rows
|
||||
// Create result table
|
||||
state.NewTable()
|
||||
for i, row := range rows {
|
||||
state.PushNumber(float64(i + 1))
|
||||
if err := state.PushTable(row); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.%s: %s", operation, err.Error()))
|
||||
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
state.SetTable(-3)
|
||||
}
|
||||
} else {
|
||||
// Return number of affected rows
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// luaSQLQuery executes a SQL query and returns results to Lua
|
||||
func luaSQLQuery(state *luajit.State) int {
|
||||
return sqlOperation(state, true)
|
||||
// sqlExec executes a SQL statement without returning results
|
||||
func sqlExec(state *luajit.State) int {
|
||||
// Get required parameters
|
||||
if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
|
||||
state.PushString("sqlite.exec: requires database name and query")
|
||||
return -1
|
||||
}
|
||||
|
||||
// luaSQLExec executes a SQL statement without returning results
|
||||
func luaSQLExec(state *luajit.State) int {
|
||||
return sqlOperation(state, false)
|
||||
dbName := state.ToString(1)
|
||||
query := state.ToString(2)
|
||||
connID := fmt.Sprintf("temp_%p", &query)
|
||||
|
||||
// Get connection
|
||||
conn, err := getConnection(dbName, connID)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// For temporary connections, defer release
|
||||
defer releaseConnection(connID)
|
||||
|
||||
// Check if parameters are provided
|
||||
hasParams := state.GetTop() >= 3 && !state.IsNil(3)
|
||||
hasPlaceholders := strings.Contains(query, "?") || strings.Contains(query, ":")
|
||||
|
||||
// Fast path for simple queries with no parameters
|
||||
if !hasParams || !hasPlaceholders {
|
||||
if err := sqlitex.ExecScript(conn, query); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Create execution options for parameterized query
|
||||
var execOpts sqlitex.ExecOptions
|
||||
|
||||
// Set up parameters
|
||||
if state.IsTable(3) {
|
||||
params, err := state.ToTable(3)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameters: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Check for array-style params
|
||||
if arr, ok := params[""]; ok {
|
||||
if arrParams, ok := arr.([]any); ok {
|
||||
execOpts.Args = arrParams
|
||||
} else if floatArr, ok := arr.([]float64); ok {
|
||||
args := make([]any, len(floatArr))
|
||||
for i, v := range floatArr {
|
||||
args[i] = v
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
} else {
|
||||
// Named parameters
|
||||
named := make(map[string]any, len(params))
|
||||
for k, v := range params {
|
||||
if len(k) > 0 && k[0] != ':' {
|
||||
named[":"+k] = v
|
||||
} else {
|
||||
named[k] = v
|
||||
}
|
||||
}
|
||||
execOpts.Named = named
|
||||
}
|
||||
} else {
|
||||
// Positional parameters
|
||||
count := state.GetTop() - 2
|
||||
args := make([]any, count)
|
||||
for i := 0; i < count; i++ {
|
||||
idx := i + 3
|
||||
switch state.GetType(idx) {
|
||||
case luajit.TypeNumber:
|
||||
args[i] = state.ToNumber(idx)
|
||||
case luajit.TypeString:
|
||||
args[i] = state.ToString(idx)
|
||||
case luajit.TypeBoolean:
|
||||
args[i] = state.ToBoolean(idx)
|
||||
case luajit.TypeNil:
|
||||
args[i] = nil
|
||||
default:
|
||||
val, err := state.ToValue(idx)
|
||||
if err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: invalid parameter %d: %s", i+1, err.Error()))
|
||||
return -1
|
||||
}
|
||||
args[i] = val
|
||||
}
|
||||
}
|
||||
execOpts.Args = args
|
||||
}
|
||||
|
||||
// Count the number of placeholders to validate parameter count
|
||||
if execOpts.Args != nil {
|
||||
placeholderCount := strings.Count(query, "?")
|
||||
if len(execOpts.Args) > placeholderCount {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: too many parameters provided (%d) for placeholders (%d)",
|
||||
len(execOpts.Args), placeholderCount))
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
// Execute with parameters
|
||||
if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
|
||||
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
|
||||
return -1
|
||||
}
|
||||
|
||||
// Return affected rows
|
||||
state.PushNumber(float64(conn.Changes()))
|
||||
return 1
|
||||
}
|
||||
|
||||
// RegisterSQLiteFunctions registers SQLite functions with the Lua state
|
||||
func RegisterSQLiteFunctions(state *luajit.State) error {
|
||||
if err := state.RegisterGoFunction("__sqlite_query", luaSQLQuery); err != nil {
|
||||
if err := state.RegisterGoFunction("__sqlite_query", sqlQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
return state.RegisterGoFunction("__sqlite_exec", luaSQLExec)
|
||||
return state.RegisterGoFunction("__sqlite_exec", sqlExec)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user