-- Simplified SQLite wrapper.
--
-- A connection is a lightweight table that only records the database name;
-- no connection IDs are tracked on the Lua side.
--- Normalize statement parameters into a single value.
-- A table argument is returned unchanged (named or positional parameters).
-- Otherwise, if any argument is non-nil, all arguments are returned as a
-- positional list in call order; nil arguments leave holes at their own
-- positions. When every argument is nil (or absent) the result is nil.
-- NOTE(review): no caller of this helper is visible in this file.
-- @param params table of parameters, or the first positional value
-- @param ... further positional values
-- @return table|nil
local function handle_params(params, ...)
  if type(params) == "table" then
    return params
  end

  -- Use next() instead of # to detect "any non-nil vararg": # is unspecified
  -- on tables with nil holes, e.g. for handle_params(nil, nil, 5).
  if params == nil and next({ ... }) == nil then
    return nil
  end

  -- A table constructor puts every argument at its call position; inserting
  -- `params` in front with table.insert would depend on # and could misplace
  -- values that follow a nil.
  return { params, ... }
end
|
|
|
|
-- Connection metatable.
--
-- Every method reads `self.db_name` and delegates to the native bindings
-- `__sqlite_query` (used by query/get_one) and `__sqlite_exec` (everything
-- else). Those natives are globals resolved at call time.
-- NOTE(review): their return contract (row tables, affected-row counts, how
-- errors are reported) and how named-parameter table keys map to `:name`
-- placeholders are not visible in this file; confirm against the binding.
-- Table and column names are interpolated into SQL verbatim (not escaped);
-- only values are bound as parameters.

-- Lua 5.1/LuaJIT provide a global `unpack`; 5.2+ moved it to `table.unpack`.
local unpack = table.unpack or unpack

--- Capture every argument, nils included, together with the true count.
local function pack(...)
  return { n = select("#", ...), ... }
end

--- Forward a statement and its parameters to a native binding.
-- @param native  native function (`__sqlite_query` or `__sqlite_exec`)
-- @param db_name database name stored on the connection
-- @param sql     SQL text
-- @param params  nil, a positional array, a named-parameter table, or the
--   first positional value
-- @param ...     further positional values (ignored when `params` is a table)
-- @return whatever `native` returns
local function call_native(native, db_name, sql, params, ...)
  if params == nil and select("#", ...) == 0 then
    -- No parameters: call with exactly (db_name, sql)
    return native(db_name, sql)
  end

  if type(params) == "table" then
    if params[1] ~= nil then
      -- Array-like: spread as positional values. A numeric `n` field (as set
      -- by table.pack) lets the list carry nils; otherwise # is used.
      local count = type(params.n) == "number" and params.n or #params
      return native(db_name, sql, unpack(params, 1, count))
    end
    -- Hash-like: named parameters, passed through as a table
    return native(db_name, sql, params)
  end

  -- Scalar first value plus varargs: forward every argument as-is so nil
  -- values keep their positions (repacking into a table and unpacking it by
  -- # could silently drop values after a nil).
  return native(db_name, sql, params, ...)
end

--- Build a (possibly multi-row) INSERT statement with `:pN` placeholders.
-- @param table_name target table (interpolated verbatim)
-- @param columns    sequence of column names (interpolated verbatim)
-- @param rows       sequence of row tables, iterated with ipairs
-- @param keys       sequence parallel to `columns`: the key read from each row
--   for that column (an index for array rows, a name for keyed rows)
-- @return sql, params
-- NOTE(review): every value becomes one bound parameter, so very large batches
-- may exceed SQLite's host-parameter limit (SQLITE_MAX_VARIABLE_NUMBER).
local function build_insert(table_name, columns, rows, keys)
  local params, groups, next_index = {}, {}, 1
  for _, row in ipairs(rows) do
    local slots = {}
    for i = 1, #keys do
      local name = "p" .. next_index
      slots[i] = ":" .. name
      params[name] = row[keys[i]]
      next_index = next_index + 1
    end
    groups[#groups + 1] = "(" .. table.concat(slots, ", ") .. ")"
  end

  local sql = string.format(
    "INSERT INTO %s (%s) VALUES %s",
    table_name,
    table.concat(columns, ", "),
    table.concat(groups, ", ")
  )
  return sql, params
end

--- Rewrite a WHERE clause so its parameters use fresh `:wN` names.
-- Named form: `where_params` is a hash-like table whose keys (with or without
-- a leading ':') name `:key` placeholders in `where`.
-- Positional form: `where_params` is an array-like table, or the first of
-- several values passed as varargs; each `?` is bound in order.
-- Bindings are stored into `params`.
-- Known limitations: placeholders inside SQL string literals are rewritten
-- too, and numbered `?NNN` placeholders are not supported.
-- @return rewritten clause, next free parameter index
local function bind_where(where, params, param_index, where_params, ...)
  if type(where_params) == "table" and where_params[1] == nil then
    local renamed = {}
    for key, value in pairs(where_params) do
      local name = "w" .. param_index
      param_index = param_index + 1
      renamed[(tostring(key):gsub("^:", ""))] = name
      params[name] = value
    end
    -- A single pass over the clause: a freshly assigned name can never be
    -- renamed a second time, and `:id` cannot match inside `:idx`.
    where = where:gsub(":([%w_]+)", function(key)
      local name = renamed[key]
      if name then
        return ":" .. name
      end
      -- returning nil keeps unknown placeholders unchanged
    end)
  elseif where_params ~= nil or select("#", ...) > 0 then
    local args = type(where_params) == "table" and where_params
      or { where_params, ... }
    local position = 0
    where = where:gsub("%?", function()
      position = position + 1
      local name = "w" .. param_index
      param_index = param_index + 1
      params[name] = args[position]
      return ":" .. name
    end)
  end
  return where, param_index
end

local connection_mt = {
  __index = {
    --- Run a row-returning statement.
    -- @param query string SQL text
    -- @param params nil, array of positional values, table of named values,
    --   or the first positional value (more may follow in `...`)
    -- @return whatever `__sqlite_query` returns
    -- @raise when `query` is not a string
    query = function(self, query, params, ...)
      if type(query) ~= "string" then
        error("connection:query: query must be a string", 2)
      end
      return call_native(__sqlite_query, self.db_name, query, params, ...)
    end,

    --- Run a statement that does not return rows.
    -- Parameters work as for `query`.
    -- @return whatever `__sqlite_exec` returns (the affected-row count,
    --   presumably; confirm against the binding)
    -- @raise when `query` is not a string
    exec = function(self, query, params, ...)
      if type(query) ~= "string" then
        error("connection:exec: query must be a string", 2)
      end
      return call_native(__sqlite_exec, self.db_name, query, params, ...)
    end,

    --- Insert one or more rows with a single INSERT statement.
    -- Accepted shapes:
    --   1. `columns` given: `data` is one array row, or a list of array rows,
    --      matched to `columns` by position.
    --   2. `data` is one keyed row `{col = value, ...}`.
    --   3. `data` is a list of keyed rows; columns come from the first row
    --      only (other rows' extra keys are ignored, missing keys bind nil).
    -- @raise on non-table data, empty `columns`, array rows without
    --   `columns`, or any other shape
    insert = function(self, table_name, data, columns)
      if type(data) ~= "table" then
        error("connection:insert: data must be a table", 2)
      end

      -- Case 1: explicit column list, array rows matched by position
      if type(columns) == "table" then
        if #columns == 0 then
          error("connection:insert: columns must not be empty", 2)
        end
        local positions = {}
        for i = 1, #columns do
          positions[i] = i
        end
        local rows = (#data > 0 and type(data[1]) == "table") and data or { data }
        local sql, params = build_insert(table_name, columns, rows, positions)
        return self:exec(sql, params)
      end

      -- Case 2: one keyed row {col1 = val1, col2 = val2}
      if data[1] == nil and next(data) ~= nil then
        local names = {}
        for column in pairs(data) do
          names[#names + 1] = column
        end
        local sql, params = build_insert(table_name, names, { data }, names)
        return self:exec(sql, params)
      end

      -- Case 3: list of keyed rows, columns inferred from the first row
      if #data > 0 and type(data[1]) == "table" then
        local first_row = data[1]
        if first_row[1] ~= nil then
          -- Column names cannot be inferred from array rows
          error("connection:insert: column names required for array data", 2)
        end
        local names = {}
        for column in pairs(first_row) do
          names[#names + 1] = column
        end
        local sql, params = build_insert(table_name, names, data, names)
        return self:exec(sql, params)
      end

      error("connection:insert: invalid data format", 2)
    end,

    --- Update rows: SET every key of `data`, optionally filtered by `where`.
    -- WHERE parameters may be a named table (`:key` placeholders), an array
    -- table, or positional values passed as `where_params, ...` (`?`
    -- placeholders).
    -- @return 0 when `data` is empty, otherwise the result of `exec`
    -- @raise when `data` is not a table
    update = function(self, table_name, data, where, where_params, ...)
      if type(data) ~= "table" then
        error("connection:update: data must be a table", 2)
      end

      -- Nothing to set
      if next(data) == nil then
        return 0
      end

      local sets, params, param_index = {}, {}, 1
      for column, value in pairs(data) do
        local name = "p" .. param_index
        sets[#sets + 1] = column .. " = :" .. name
        params[name] = value
        param_index = param_index + 1
      end

      local sql = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", "))
      if where then
        -- The SQL is assembled after the clause is rewritten, so renamed
        -- placeholders always reach the statement.
        local clause = bind_where(where, params, param_index, where_params, ...)
        sql = sql .. " WHERE " .. clause
      end

      return self:exec(sql, params)
    end,

    --- Create a table (IF NOT EXISTS) plus optional indices.
    -- String arguments are column definitions, except ones of the form
    -- "INDEX:name(col, ...)" or "UNIQUE INDEX:name(col, ...)", which declare
    -- indices. Non-string arguments (including nil) are ignored.
    -- The CREATE TABLE and each CREATE INDEX run as separate `exec` calls
    -- (not atomically).
    -- @return the results of the CREATE TABLE `exec`
    -- @raise when no columns are given or an index definition is malformed
    create_table = function(self, table_name, ...)
      local columns, indices = {}, {}
      local defs = { ... }

      -- select('#') so a nil argument does not truncate the scan the way
      -- ipairs would
      for i = 1, select("#", ...) do
        local def = defs[i]
        if type(def) == "string" then
          -- Lua patterns have no alternation, so try both prefixes in turn
          local unique = false
          local spec = def:match("^UNIQUE%s+INDEX:(.+)")
          if spec then
            unique = true
          else
            spec = def:match("^INDEX:(.+)")
          end

          if spec then
            local index_name, columns_str = spec:match("([%w_]+)%(([^)]+)%)")
            if not index_name then
              error("connection:create_table: invalid index definition: " .. def, 2)
            end
            local index_columns = {}
            for column in columns_str:gmatch("[^,]+") do
              index_columns[#index_columns + 1] = column:match("^%s*(.-)%s*$")
            end
            indices[#indices + 1] = {
              name = index_name,
              columns = index_columns,
              unique = unique,
            }
          else
            columns[#columns + 1] = def
          end
        end
      end

      if #columns == 0 then
        error("connection:create_table: no columns specified", 2)
      end

      local results = pack(self:exec(string.format(
        "CREATE TABLE IF NOT EXISTS %s (%s)",
        table_name,
        table.concat(columns, ", ")
      )))

      -- One statement per exec call: whether the native binding accepts
      -- several ';'-separated statements in one call is not known here.
      for _, index in ipairs(indices) do
        self:exec(string.format(
          "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
          index.unique and "UNIQUE " or "",
          index.name,
          table_name,
          table.concat(index.columns, ", ")
        ))
      end

      return unpack(results, 1, results.n)
    end,

    --- Delete rows, optionally filtered by `where`.
    -- Parameters (`params, ...`) are forwarded to `exec` unchanged.
    delete = function(self, table_name, where, params, ...)
      local sql = "DELETE FROM " .. table_name
      if where then
        sql = sql .. " WHERE " .. where
      end
      return self:exec(sql, params, ...)
    end,

    --- Return the first row of a query, or nil when there is none.
    -- Appends " LIMIT 1" unless the query already contains the word LIMIT
    -- (in any form, e.g. `LIMIT ?`); trailing whitespace and semicolons are
    -- stripped first so the clause stays inside the statement.
    -- @raise when `query` is not a string or the query yields no result table
    get_one = function(self, query, params, ...)
      if type(query) ~= "string" then
        error("connection:get_one: query must be a string", 2)
      end

      local sql = query
      if not query:lower():find("%f[%w_]limit%f[^%w_]") then
        sql = (query:gsub("[%s;]+$", "")) .. " LIMIT 1"
      end

      local rows, err = self:query(sql, params, ...)
      if type(rows) ~= "table" then
        error("connection:get_one: query failed: " .. tostring(err), 2)
      end
      return rows[1]
    end,

    --- Begin a transaction.
    begin = function(self)
      return self:exec("BEGIN TRANSACTION")
    end,

    --- Commit the current transaction.
    commit = function(self)
      return self:exec("COMMIT")
    end,

    --- Roll back the current transaction.
    rollback = function(self)
      return self:exec("ROLLBACK")
    end,

    --- Run `callback(self)` inside BEGIN ... COMMIT.
    -- If the callback (or the COMMIT) raises, a ROLLBACK is attempted and the
    -- original error is re-raised unchanged.
    -- @return every value returned by `callback`
    transaction = function(self, callback)
      self:begin()

      local outcome = pack(pcall(callback, self))
      if not outcome[1] then
        -- Best effort: a failing ROLLBACK must not hide the callback's error
        pcall(self.rollback, self)
        -- Level 0: the error already carries its own position information
        error(outcome[2], 0)
      end

      local committed, commit_err = pcall(self.commit, self)
      if not committed then
        -- Do not leave the transaction open after a failed COMMIT
        pcall(self.rollback, self)
        error(commit_err, 0)
      end

      return unpack(outcome, 2, outcome.n)
    end,
  },
}
|
|
|
|
--- Module entry point: create a connection object for a database.
-- The returned table only stores the database name; its methods come from
-- `connection_mt`.
-- @tparam string db_name name of the database handed to the native bindings
-- @treturn table connection object
-- @raise "sqlite: database name must be a string" for non-string names
return function(db_name)
  local name_is_string = type(db_name) == "string"
  if not name_is_string then
    error("sqlite: database name must be a string", 2)
  end

  return setmetatable({ db_name = db_name }, connection_mt)
end
|