-- Lua source: 232 lines, 5.5 KiB
-- Registry of connection objects created by the sqlite() factory returned at
-- the end of this file, keyed by each connection's `id`. Deliberately a
-- global (not `local`); nothing in this file ever removes entries.
-- NOTE(review): presumably read by the host that implements
-- __sqlite_query/__sqlite_exec; confirm before renaming or localizing it.
__active_sqlite_connections = {}
-- Normalize the parameter arguments accepted by connection:query/exec.
--
-- @param params  a table (named-parameter map or positional list, returned
--                unchanged), or the first positional value (may be nil)
-- @param ...     further positional values (nils allowed)
-- @return nil when no parameters were supplied at all;
--         the `params` table itself when a table was passed;
--         otherwise a fresh positional list plus its length as a second
--         value, so nil (SQL NULL) entries keep their positions
local function handle_params(params, ...)
  -- A table is passed through untouched: named params or a positional list.
  if type(params) == "table" then
    return params
  end

  -- select('#', ...) counts trailing nils, unlike `#` on a table with holes.
  local extra = select("#", ...)
  if params == nil and extra == 0 then
    return nil
  end

  -- Build the list in one constructor instead of inserting at index 1, which
  -- misbehaves when `params` itself is nil.
  return { params, ... }, extra + 1
end
-- Lua 5.1 / LuaJIT provide a global `unpack`; 5.2+ moved it to `table.unpack`.
local unpack_args = table.unpack or unpack

-- Capture every value in `...`, including nils, together with its count.
local function pack_values(...)
  return { n = select("#", ...), ... }
end

-- Return the keys of `t` sorted by their string form. pairs() order is
-- unspecified, so sorting keeps the generated SQL text deterministic.
local function sorted_keys(t)
  local keys = {}
  for key in pairs(t) do
    keys[#keys + 1] = key
  end
  table.sort(keys, function(a, b)
    return tostring(a) < tostring(b)
  end)
  return keys
end

-- Forward a statement to a host binding `fn` (__sqlite_query/__sqlite_exec).
-- Argument layout passed to `fn` (unchanged calling convention):
--   positional values:  fn(db_name, query, v1, ..., vN, conn_id)
--   named table / none: fn(db_name, query, params_or_nil, conn_id)
-- Positional values are copied into a fresh argument list, so the caller's
-- params table is never modified. They are counted explicitly so that nil
-- (SQL NULL) values keep their positions.
local function call_with_params(fn, self, query, params, ...)
  local processed, count = handle_params(params, ...)

  if processed == nil then
    return fn(self.db_name, query, nil, self.id)
  end

  if type(params) == "table" then
    if params[1] == nil then
      -- Named parameters (or an empty table): pass the table through as-is.
      return fn(self.db_name, query, params, self.id)
    end
    count = #params -- positional list supplied as a single table
  else
    -- Positional scalars. Fall back to counting the arguments directly if
    -- handle_params did not report a count.
    count = count or (select("#", ...) + 1)
  end

  local args = { self.db_name, query }
  for i = 1, count do
    args[i + 2] = processed[i]
  end
  args[count + 3] = self.id
  return fn(unpack_args(args, 1, count + 3))
end

-- Connection metatable: methods shared by every connection object. Each
-- connection carries `db_name` and `id`, which are forwarded to the host
-- bindings on every call.
local connection_mt = {
  __index = {
    -- Execute a query and return its results (returned by __sqlite_query).
    -- Parameters may be a named table, a positional list table, or positional
    -- values passed directly as extra arguments.
    query = function(self, query, params, ...)
      if type(query) ~= "string" then
        error("connection:query: query must be a string", 2)
      end
      return call_with_params(__sqlite_query, self, query, params, ...)
    end,

    -- Execute a statement and return affected rows (returned by
    -- __sqlite_exec). Parameters take the same forms as connection:query.
    exec = function(self, query, params, ...)
      if type(query) ~= "string" then
        error("connection:exec: query must be a string", 2)
      end
      return call_with_params(__sqlite_exec, self, query, params, ...)
    end,

    -- Create `table_name` if it does not exist. Each extra argument is one
    -- column definition string, e.g. "id INTEGER PRIMARY KEY".
    -- NOTE(review): the table name and column definitions are spliced into
    -- the SQL verbatim; never pass untrusted input here.
    create_table = function(self, table_name, ...)
      local columns = { ... }
      if #columns == 0 then
        error("connection:create_table: no columns specified", 2)
      end

      local query = string.format("CREATE TABLE IF NOT EXISTS %s (%s)",
        table_name, table.concat(columns, ", "))
      return self:exec(query)
    end,

    -- Insert one row (a non-empty map of column -> value) or several rows (a
    -- list of such maps, one INSERT each). A single row returns exec's
    -- result; multiple rows return the sum of those results.
    -- NOTE(review): the table and column names are spliced into the SQL
    -- verbatim; never pass untrusted keys here.
    insert = function(self, table_name, data)
      if type(data) ~= "table" then
        error("connection:insert: data must be a table", 2)
      end

      -- Single row: a non-empty map with no array part.
      if data[1] == nil and next(data) ~= nil then
        local columns = sorted_keys(data)
        local placeholders, params = {}, {}
        for i, col in ipairs(columns) do
          local name = ":" .. col
          placeholders[i] = name
          params[name] = data[col]
        end

        local query = string.format(
          "INSERT INTO %s (%s) VALUES (%s)",
          table_name,
          table.concat(columns, ", "),
          table.concat(placeholders, ", ")
        )
        return self:exec(query, params)
      end

      -- Multiple rows: a list of row maps.
      if #data > 0 and type(data[1]) == "table" then
        local affected = 0
        for _, row in ipairs(data) do
          affected = affected + self:insert(table_name, row)
        end
        return affected
      end

      error("connection:insert: invalid data format", 2)
    end,

    -- Update rows of `table_name`, setting every column in `data` (a non-empty
    -- map of column -> value). `where` is an optional SQL condition appended
    -- verbatim. `where_params` (used only together with `where`) may be:
    --   * a named table: merged with the SET values, which use `:column`
    --     placeholders;
    --   * a list table or a single non-table value: the WHERE clause is taken
    --     to use `?` placeholders, so SET values are bound positionally
    --     first, followed by the WHERE values.
    -- NOTE(review): in named mode the SET values are keyed by the bare column
    -- name, while insert keys them as ":" .. column; confirm that the host
    -- binding accepts both forms.
    update = function(self, table_name, data, where, where_params)
      if type(data) ~= "table" then
        error("connection:update: data must be a table", 2)
      end
      if next(data) == nil then
        error("connection:update: no columns to update", 2)
      end

      -- Decide whether the WHERE values are positional.
      local where_list = nil
      if where and where_params then
        if type(where_params) ~= "table" then
          where_list = { where_params }
        elseif where_params[1] ~= nil then
          where_list = where_params
        end
      end

      local columns = sorted_keys(data)
      local sets, params = {}, {}

      if where_list then
        for i, col in ipairs(columns) do
          sets[i] = col .. " = ?"
          params[i] = data[col]
        end
        local base = #columns
        for i = 1, #where_list do
          params[base + i] = where_list[i]
        end
      else
        for i, col in ipairs(columns) do
          sets[i] = col .. " = :" .. col
          params[col] = data[col]
        end
        if where and where_params then
          for k, v in pairs(where_params) do
            params[k] = v
          end
        end
      end

      local query = string.format(
        "UPDATE %s SET %s",
        table_name,
        table.concat(sets, ", ")
      )
      if where then
        query = query .. " WHERE " .. where
      end

      return self:exec(query, params)
    end,

    -- Delete rows of `table_name` matching the optional `where` condition
    -- (every row when `where` is nil). `params` and any further positional
    -- values are forwarded to exec unchanged.
    delete = function(self, table_name, where, params, ...)
      local query = "DELETE FROM " .. table_name
      if where then
        query = query .. " WHERE " .. where
      end
      return self:exec(query, params, ...)
    end,

    -- Run `query` and return the first element of its result (nil if empty).
    -- If the query returns nil, that nil and any accompanying error value are
    -- passed back rather than indexing nil.
    get_one = function(self, query, params, ...)
      local rows, err = self:query(query, params, ...)
      if rows == nil then
        return nil, err
      end
      return rows[1]
    end,

    -- Begin transaction
    begin = function(self)
      return self:exec("BEGIN TRANSACTION")
    end,

    -- Commit transaction
    commit = function(self)
      return self:exec("COMMIT")
    end,

    -- Rollback transaction
    rollback = function(self)
      return self:exec("ROLLBACK")
    end,

    -- Run callback(self) between BEGIN and COMMIT and return every value the
    -- callback returns. If the callback raises, the transaction is rolled
    -- back and the error is re-raised unchanged. If COMMIT raises, a rollback
    -- is attempted and the commit error is re-raised.
    transaction = function(self, callback)
      self:begin()

      -- outcome = { ok, result1, result2, ..., n = count }
      local outcome = pack_values(pcall(callback, self))
      if not outcome[1] then
        -- A failing rollback must not hide the callback's error.
        pcall(self.rollback, self)
        -- Level 0 re-raises the message without adding another position prefix.
        error(outcome[2], 0)
      end

      local committed, commit_err = pcall(self.commit, self)
      if not committed then
        -- The transaction may still be open after a failed COMMIT.
        pcall(self.rollback, self)
        error(commit_err, 0)
      end

      return unpack_args(outcome, 2, outcome.n)
    end
  }
}
-- Module value: the sqlite(db_name) factory. Each call creates a new
-- connection object for `db_name`, registers it in the global
-- __active_sqlite_connections table under its id, and returns it with the
-- connection methods attached.
-- @param db_name string database name, forwarded to the host bindings
-- @return connection object with fields `db_name` and `id`
return function(db_name)
  if type(db_name) ~= "string" then
    error("sqlite: database name must be a string", 2)
  end

  local conn = { db_name = db_name }

  -- Derive the id from conn's own address. conn is referenced by the registry
  -- below, so no other live table can share this address. A temporary `{}`,
  -- by contrast, could be collected and its address reused for a later
  -- connection, giving two connections the same id.
  local address = tostring(conn)
  conn.id = address:match("table: (.*)") or address

  __active_sqlite_connections[conn.id] = conn
  return setmetatable(conn, connection_mt)
end