drop connection tracking to slim down sqlite handling

This commit is contained in:
Sky Johnson 2025-05-10 15:08:56 -05:00
parent 98b2931d59
commit 8f9a9da5a1
2 changed files with 96 additions and 108 deletions

View File

@ -1,4 +1,5 @@
__active_sqlite_connections = {} -- Simplified SQLite wrapper
-- Connection is now lightweight, we don't need to track IDs
-- Helper function to handle parameters -- Helper function to handle parameters
local function handle_params(params, ...) local function handle_params(params, ...)
@ -29,7 +30,7 @@ local connection_mt = {
-- Fast path for no parameters -- Fast path for no parameters
if params == nil and select('#', ...) == 0 then if params == nil and select('#', ...) == 0 then
return __sqlite_query(self.db_name, query, nil, self.id) return __sqlite_query(self.db_name, query)
end end
-- Handle various parameter types efficiently -- Handle various parameter types efficiently
@ -42,12 +43,10 @@ local connection_mt = {
for i=1, #params do for i=1, #params do
args[i+2] = params[i] args[i+2] = params[i]
end end
-- Add connection ID
args[#args+1] = self.id
return __sqlite_query(unpack(args)) return __sqlite_query(unpack(args))
else else
-- Named parameters -- Named parameters
return __sqlite_query(self.db_name, query, params, self.id) return __sqlite_query(self.db_name, query, params)
end end
else else
-- Variadic parameters, combine with first param -- Variadic parameters, combine with first param
@ -56,7 +55,6 @@ local connection_mt = {
for i=1, n do for i=1, n do
args[i+3] = select(i, ...) args[i+3] = select(i, ...)
end end
args[#args+1] = self.id
return __sqlite_query(unpack(args)) return __sqlite_query(unpack(args))
end end
end, end,
@ -69,7 +67,7 @@ local connection_mt = {
-- Fast path for no parameters -- Fast path for no parameters
if params == nil and select('#', ...) == 0 then if params == nil and select('#', ...) == 0 then
return __sqlite_exec(self.db_name, query, nil, self.id) return __sqlite_exec(self.db_name, query)
end end
-- Handle various parameter types efficiently -- Handle various parameter types efficiently
@ -82,12 +80,10 @@ local connection_mt = {
for i=1, #params do for i=1, #params do
args[i+2] = params[i] args[i+2] = params[i]
end end
-- Add connection ID
args[#args+1] = self.id
return __sqlite_exec(unpack(args)) return __sqlite_exec(unpack(args))
else else
-- Named parameters -- Named parameters
return __sqlite_exec(self.db_name, query, params, self.id) return __sqlite_exec(self.db_name, query, params)
end end
else else
-- Variadic parameters, combine with first param -- Variadic parameters, combine with first param
@ -96,79 +92,11 @@ local connection_mt = {
for i=1, n do for i=1, n do
args[i+3] = select(i, ...) args[i+3] = select(i, ...)
end end
args[#args+1] = self.id
return __sqlite_exec(unpack(args)) return __sqlite_exec(unpack(args))
end end
end, end,
-- Create a new table -- Insert a row or multiple rows with a single query
create_table = function(self, table_name, ...)
local columns = {}
local indices = {}
-- Process all arguments
for _, def in ipairs({...}) do
if type(def) == "string" then
-- Check if it's an index definition
local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)")
if index_def then
-- Parse index definition
local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)")
if index_name and columns_str then
-- Split columns by comma
local index_columns = {}
for col in columns_str:gmatch("[^,]+") do
table.insert(index_columns, col:match("^%s*(.-)%s*$")) -- Trim whitespace
end
table.insert(indices, {
name = index_name,
columns = index_columns,
unique = (index_type == "UNIQUE INDEX:")
})
end
else
-- Regular column definition
table.insert(columns, def)
end
end
end
if #columns == 0 then
error("connection:create_table: no columns specified", 2)
end
-- Build combined statement for table and indices
local statements = {}
-- Add the CREATE TABLE statement
table.insert(statements, string.format(
"CREATE TABLE IF NOT EXISTS %s (%s)",
table_name,
table.concat(columns, ", ")
))
-- Add CREATE INDEX statements
for _, idx in ipairs(indices) do
local unique = idx.unique and "UNIQUE " or ""
table.insert(statements, string.format(
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
unique,
idx.name,
table_name,
table.concat(idx.columns, ", ")
))
end
-- Execute all statements in a single transaction
local combined_sql = table.concat(statements, ";\n")
return self:exec(combined_sql)
end,
-- Insert a row or multiple rows
insert = function(self, table_name, data, columns) insert = function(self, table_name, data, columns)
if type(data) ~= "table" then if type(data) ~= "table" then
error("connection:insert: data must be a table", 2) error("connection:insert: data must be a table", 2)
@ -296,7 +224,7 @@ local connection_mt = {
error("connection:insert: invalid data format", 2) error("connection:insert: invalid data format", 2)
end, end,
-- Update rows -- Update rows in a table
update = function(self, table_name, data, where, where_params, ...) update = function(self, table_name, data, where, where_params, ...)
if type(data) ~= "table" then if type(data) ~= "table" then
error("connection:update: data must be a table", 2) error("connection:update: data must be a table", 2)
@ -382,6 +310,73 @@ local connection_mt = {
return self:exec(query, params) return self:exec(query, params)
end, end,
-- Create a new table
create_table = function(self, table_name, ...)
local columns = {}
local indices = {}
-- Process all arguments
for _, def in ipairs({...}) do
if type(def) == "string" then
-- Check if it's an index definition
local index_type, index_def = def:match("^(UNIQUE%s+INDEX:|INDEX:)(.+)")
if index_def then
-- Parse index definition
local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)")
if index_name and columns_str then
-- Split columns by comma
local index_columns = {}
for col in columns_str:gmatch("[^,]+") do
table.insert(index_columns, col:match("^%s*(.-)%s*$")) -- Trim whitespace
end
table.insert(indices, {
name = index_name,
columns = index_columns,
unique = (index_type == "UNIQUE INDEX:")
})
end
else
-- Regular column definition
table.insert(columns, def)
end
end
end
if #columns == 0 then
error("connection:create_table: no columns specified", 2)
end
-- Build combined statement for table and indices
local statements = {}
-- Add the CREATE TABLE statement
table.insert(statements, string.format(
"CREATE TABLE IF NOT EXISTS %s (%s)",
table_name,
table.concat(columns, ", ")
))
-- Add CREATE INDEX statements
for _, idx in ipairs(indices) do
local unique = idx.unique and "UNIQUE " or ""
table.insert(statements, string.format(
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
unique,
idx.name,
table_name,
table.concat(idx.columns, ", ")
))
end
-- Execute all statements in a single transaction
local combined_sql = table.concat(statements, ";\n")
return self:exec(combined_sql)
end,
-- Delete rows -- Delete rows
delete = function(self, table_name, where, params) delete = function(self, table_name, where, params)
local query = "DELETE FROM " .. table_name local query = "DELETE FROM " .. table_name
@ -456,11 +451,8 @@ return function(db_name)
end end
local conn = { local conn = {
db_name = db_name, db_name = db_name
id = tostring({}):match("table: (.*)") -- unique ID based on table address
} }
__active_sqlite_connections[conn.id] = conn
return setmetatable(conn, connection_mt) return setmetatable(conn, connection_mt)
end end

View File

@ -189,8 +189,8 @@ func releaseConnection(connID string) {
// sqlQuery executes a SQL query and returns results // sqlQuery executes a SQL query and returns results
func sqlQuery(state *luajit.State) int { func sqlQuery(state *luajit.State) int {
// Get required parameters // Get required parameters
if state.GetTop() < 3 || !state.IsString(1) || !state.IsString(2) { if state.GetTop() < 2 || !state.IsString(1) || !state.IsString(2) {
state.PushString("sqlite.query: requires database name, query, and optional parameters") state.PushString("sqlite.query: requires database name and query")
return -1 return -1
} }
@ -204,16 +204,14 @@ func sqlQuery(state *luajit.State) int {
state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.query: %s", err.Error()))
return -1 return -1
} }
defer releaseConnection(connID)
// Create execution options // Create execution options
var execOpts sqlitex.ExecOptions var execOpts sqlitex.ExecOptions
rows := make([]map[string]any, 0, 16) rows := make([]map[string]any, 0, 16)
// For temporary connections, defer release
defer releaseConnection(connID)
// Set up parameters if provided // Set up parameters if provided
if state.GetTop() >= 3 { if state.GetTop() >= 3 && !state.IsNil(3) {
if state.IsTable(3) { if state.IsTable(3) {
params, err := state.ToTable(3) params, err := state.ToTable(3)
if err != nil { if err != nil {
@ -248,7 +246,7 @@ func sqlQuery(state *luajit.State) int {
// Positional parameters // Positional parameters
count := state.GetTop() - 2 count := state.GetTop() - 2
args := make([]any, count) args := make([]any, count)
for i := 0; i < count; i++ { for i := range count {
idx := i + 3 idx := i + 3
switch state.GetType(idx) { switch state.GetType(idx) {
case luajit.TypeNumber: case luajit.TypeNumber:
@ -294,7 +292,7 @@ func sqlQuery(state *luajit.State) int {
row[colName] = nil row[colName] = nil
} }
} }
rows = append(rows, row) // No need to copy, this row is used only once rows = append(rows, row)
return nil return nil
} }
@ -336,17 +334,25 @@ func sqlExec(state *luajit.State) int {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1 return -1
} }
// For temporary connections, defer release
defer releaseConnection(connID) defer releaseConnection(connID)
// Check if parameters are provided // Check if parameters are provided
hasParams := state.GetTop() >= 3 && !state.IsNil(3) hasParams := state.GetTop() >= 3 && !state.IsNil(3)
hasPlaceholders := strings.Contains(query, "?") || strings.Contains(query, ":")
// Fast path for multi-statement scripts - use ExecScript
if strings.Contains(query, ";") && !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1
}
state.PushNumber(float64(conn.Changes()))
return 1
}
// Fast path for simple queries with no parameters // Fast path for simple queries with no parameters
if !hasParams || !hasPlaceholders { if !hasParams {
if err := sqlitex.ExecScript(conn, query); err != nil { // Use Execute for simple statements without parameters
if err := sqlitex.Execute(conn, query, nil); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))
return -1 return -1
} }
@ -392,7 +398,7 @@ func sqlExec(state *luajit.State) int {
// Positional parameters // Positional parameters
count := state.GetTop() - 2 count := state.GetTop() - 2
args := make([]any, count) args := make([]any, count)
for i := 0; i < count; i++ { for i := range count {
idx := i + 3 idx := i + 3
switch state.GetType(idx) { switch state.GetType(idx) {
case luajit.TypeNumber: case luajit.TypeNumber:
@ -415,16 +421,6 @@ func sqlExec(state *luajit.State) int {
execOpts.Args = args execOpts.Args = args
} }
// Count the number of placeholders to validate parameter count
if execOpts.Args != nil {
placeholderCount := strings.Count(query, "?")
if len(execOpts.Args) > placeholderCount {
state.PushString(fmt.Sprintf("sqlite.exec: too many parameters provided (%d) for placeholders (%d)",
len(execOpts.Args), placeholderCount))
return -1
}
}
// Execute with parameters // Execute with parameters
if err := sqlitex.Execute(conn, query, &execOpts); err != nil { if err := sqlitex.Execute(conn, query, &execOpts); err != nil {
state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error())) state.PushString(fmt.Sprintf("sqlite.exec: %s", err.Error()))