-- Moonshark/runner/lua/sqlite.lua
-- (source listing metadata: 459 lines, 12 KiB, Lua)

-- Simplified SQLite wrapper.
-- A connection is lightweight: it only stores the database name, so no
-- connection IDs need to be tracked.
-- Normalizes query parameters into a single table.
--   * A table `params` is returned unchanged (extra arguments are ignored).
--   * Otherwise `params` and the varargs are gathered into one array with
--     `params` as its first element.
--   * Returns nil when no parameters were supplied at all.
-- NOTE(review): not referenced anywhere in the visible part of this file.
local function handle_params(params, ...)
  if type(params) == "table" then
    return params
  end

  local positional = {...}
  if params == nil and #positional == 0 then
    return nil
  end

  -- Put the leading parameter in front of the collected varargs.
  table.insert(positional, 1, params)
  return positional
end
-- Connection metatable
local connection_mt = {
__index = {
-- Runs a query through __sqlite_query and returns whatever that call returns.
-- (__sqlite_query is not defined in this file; presumably a global supplied
-- by the host runtime.)
--
-- Parameters can be given in three ways:
--   * array-like table ({v1, v2, ...}): elements are spread as positional
--     arguments; any varargs are ignored
--   * any other table: passed through as one named-parameter table
--   * plain values (params, ...): forwarded as positional arguments
--
-- Raises an error (reported at the caller) if `query` is not a string.
query = function(self, query, params, ...)
  if type(query) ~= "string" then
    error("connection:query: query must be a string", 2)
  end

  if type(params) == "table" then
    if params[1] ~= nil then
      -- unpack is a global in Lua 5.1/LuaJIT and lives in table on 5.2+.
      local spread = table.unpack or unpack
      return __sqlite_query(self.db_name, query, spread(params, 1, #params))
    end
    return __sqlite_query(self.db_name, query, params)
  end

  -- No parameters at all: call with just the database name and the query.
  if params == nil and select('#', ...) == 0 then
    return __sqlite_query(self.db_name, query)
  end

  -- Forward the values directly instead of packing them into a table and
  -- unpacking again: a packed table with nil holes has an unspecified
  -- length, which could silently drop arguments.
  return __sqlite_query(self.db_name, query, params, ...)
end,
-- Runs a statement through __sqlite_exec and returns whatever that call
-- returns (presumably the affected-row count; __sqlite_exec is not defined
-- in this file).
--
-- Parameters can be given in three ways:
--   * array-like table ({v1, v2, ...}): elements are spread as positional
--     arguments; any varargs are ignored
--   * any other table: passed through as one named-parameter table
--   * plain values (params, ...): forwarded as positional arguments
--
-- Raises an error (reported at the caller) if `query` is not a string.
exec = function(self, query, params, ...)
  if type(query) ~= "string" then
    error("connection:exec: query must be a string", 2)
  end

  if type(params) == "table" then
    if params[1] ~= nil then
      -- unpack is a global in Lua 5.1/LuaJIT and lives in table on 5.2+.
      local spread = table.unpack or unpack
      return __sqlite_exec(self.db_name, query, spread(params, 1, #params))
    end
    return __sqlite_exec(self.db_name, query, params)
  end

  -- No parameters at all: call with just the database name and the query.
  if params == nil and select('#', ...) == 0 then
    return __sqlite_exec(self.db_name, query)
  end

  -- Forward the values directly instead of packing them into a table and
  -- unpacking again: a packed table with nil holes has an unspecified
  -- length, which could silently drop arguments.
  return __sqlite_exec(self.db_name, query, params, ...)
end,
-- Inserts one or more rows with a single INSERT statement run via self:exec,
-- and returns whatever self:exec returns.
--
-- Accepted shapes of `data`:
--   1. with `columns` (array of column names):
--        * array of arrays  {{v1, v2}, {v1, v2}}  -> multi-row insert
--        * single array     {v1, v2}              -> one row
--   2. without `columns`, object-style row {col = val, ...} -> one row
--   3. without `columns`, array of object-style rows; the column list is
--      taken from the keys of the first row
--
-- Values are bound as named parameters :p1, :p2, ... numbered by position.
-- A nil value is written as a literal NULL rather than a placeholder with
-- no binding. `table_name` and the column names are interpolated into the
-- SQL as-is, so they must come from trusted code.
--
-- Raises an error (reported at the caller) when `data` is not a table, when
-- array rows are given without `columns`, or when `data` matches no shape.
insert = function(self, table_name, data, columns)
  if type(data) ~= "table" then
    error("connection:insert: data must be a table", 2)
  end

  local params = {}
  local param_index = 0

  -- Returns the SQL fragment for one value: ":pN" (and records the binding)
  -- or "NULL" when there is nothing to bind.
  local function bind(value)
    param_index = param_index + 1
    if value == nil then
      return "NULL"
    end
    local name = "p" .. param_index
    params[name] = value
    return ":" .. name
  end

  -- Builds "(…, …)" for one row; values are looked up by column name when
  -- by_name is true, otherwise by column position.
  local function value_group(row, keys, by_name)
    local parts = {}
    for position, key in ipairs(keys) do
      if by_name then
        parts[position] = bind(row[key])
      else
        parts[position] = bind(row[position])
      end
    end
    return "(" .. table.concat(parts, ", ") .. ")"
  end

  -- Assembles the statement and executes it with the collected bindings.
  local function run(column_names, groups)
    local query = string.format(
      "INSERT INTO %s (%s) VALUES %s",
      table_name,
      table.concat(column_names, ", "),
      table.concat(groups, ", ")
    )
    return self:exec(query, params)
  end

  -- Case 1: explicit column list with positional row data
  if type(columns) == "table" then
    if #data > 0 and type(data[1]) == "table" then
      local groups = {}
      for i, row in ipairs(data) do
        groups[i] = value_group(row, columns, false)
      end
      return run(columns, groups)
    end
    return run(columns, { value_group(data, columns, false) })
  end

  -- Case 2: object-style single row {col1 = val1, col2 = val2}
  if data[1] == nil and next(data) ~= nil then
    local names, parts = {}, {}
    for col, val in pairs(data) do
      names[#names + 1] = col
      parts[#parts + 1] = bind(val)
    end
    return run(names, { "(" .. table.concat(parts, ", ") .. ")" })
  end

  -- Case 3: array of object-style rows, columns inferred from the first row
  if #data > 0 and type(data[1]) == "table" then
    local first_row = data[1]
    if first_row[1] ~= nil then
      -- Array rows carry no column names to infer.
      error("connection:insert: column names required for array data", 2)
    end

    local names = {}
    for col in pairs(first_row) do
      names[#names + 1] = col
    end

    local groups = {}
    for i, row in ipairs(data) do
      groups[i] = value_group(row, names, true)
    end
    return run(names, groups)
  end

  error("connection:insert: invalid data format", 2)
end,
-- Updates rows of `table_name`, setting each column in `data` (a table of
-- column = value pairs), and returns whatever self:exec returns. Returns 0
-- without executing anything when `data` is empty.
--
-- `where` is an optional SQL condition. Its parameters may be supplied as:
--   * a table with string keys: named parameters. A leading ":" in a key is
--     optional, so {id = 3} and {[":id"] = 3} both bind :id
--   * an array-like table, or plain values (where_params, ...): positional
--     values that replace the "?" marks in `where` from left to right
--
-- All values reach self:exec as a single named-parameter table. SET values
-- are bound as :pN and positional WHERE values as :wN; a generated name that
-- is already taken by a caller-supplied named parameter is skipped.
-- A positional value that is nil (or missing) is written as a literal NULL;
-- note that "col = NULL" never matches in SQL, use "col IS NULL" instead.
-- Every "?" in `where` is treated as a placeholder, including one inside a
-- quoted SQL string.
--
-- `table_name`, the column names and `where` are interpolated into the SQL
-- as-is, so they must come from trusted code.
--
-- Raises an error (reported at the caller) if `data` is not a table.
update = function(self, table_name, data, where, where_params, ...)
  if type(data) ~= "table" then
    error("connection:update: data must be a table", 2)
  end
  -- Nothing to set: skip the round trip
  if next(data) == nil then
    return 0
  end

  local params = {}
  local param_index = 0

  -- Produces the next unused parameter name with the given prefix.
  local function fresh_name(prefix)
    local name
    repeat
      param_index = param_index + 1
      name = prefix .. param_index
    until params[name] == nil
    return name
  end

  -- Collect the WHERE inputs first so that generated names can avoid
  -- colliding with names chosen by the caller.
  local positional, positional_count = nil, 0
  if where then
    if type(where_params) == "table" then
      for key, value in pairs(where_params) do
        if type(key) == "string" then
          local name = key
          if name:sub(1, 1) == ":" then
            name = name:sub(2)
          end
          params[name] = value
        end
      end
      positional, positional_count = where_params, #where_params
    else
      local extra = select('#', ...)
      if where_params ~= nil or extra > 0 then
        positional = { where_params, ... }
        positional_count = extra + 1
      end
    end
  end

  local sets = {}
  for col, val in pairs(data) do
    local name = fresh_name("p")
    sets[#sets + 1] = col .. " = :" .. name
    params[name] = val
  end

  local query = string.format(
    "UPDATE %s SET %s",
    table_name,
    table.concat(sets, ", ")
  )

  if where then
    local clause = where
    if positional_count > 0 then
      -- Walk the "?" marks with their own counter; the n-th mark takes the
      -- n-th positional value.
      local used = 0
      clause = where:gsub("%?", function()
        used = used + 1
        local value = positional[used]
        if value == nil then
          return "NULL"
        end
        local name = fresh_name("w")
        params[name] = value
        return ":" .. name
      end)
    end
    query = query .. " WHERE " .. clause
  end

  return self:exec(query, params)
end,
-- Creates a table (IF NOT EXISTS) together with optional indices and returns
-- whatever self:exec returns.
--
-- Each string argument after `table_name` is either
--   * an index definition: "INDEX:name(col1, col2)" or
--     "UNIQUE INDEX:name(col1, col2)", or
--   * a column definition (any other string), used verbatim.
-- Non-string arguments are ignored.
--
-- All statements are joined with ";\n" and sent in one self:exec call.
-- NOTE(review): this relies on the backend accepting several statements in
-- a single exec; whether they run in one transaction is not visible here.
--
-- Raises an error (reported at the caller) when no column definition is
-- given or when an index definition is not of the form name(columns).
create_table = function(self, table_name, ...)
  local columns = {}
  local index_statements = {}

  -- select('#', ...) rather than ipairs({...}), so a nil argument does not
  -- cut the list short.
  for i = 1, select('#', ...) do
    local def = select(i, ...)
    if type(def) == "string" then
      -- Lua patterns have no alternation, so the two prefixes are tried
      -- one after the other.
      local unique = true
      local index_def = def:match("^UNIQUE%s+INDEX:(.+)")
      if not index_def then
        unique = false
        index_def = def:match("^INDEX:(.+)")
      end

      if index_def then
        local index_name, columns_str = index_def:match("([%w_]+)%(([^)]+)%)")
        if not index_name then
          error("connection:create_table: invalid index definition: " .. def, 2)
        end

        -- Split the column list on commas and trim surrounding whitespace
        local index_columns = {}
        for col in columns_str:gmatch("[^,]+") do
          index_columns[#index_columns + 1] = col:match("^%s*(.-)%s*$")
        end

        index_statements[#index_statements + 1] = string.format(
          "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
          unique and "UNIQUE " or "",
          index_name,
          table_name,
          table.concat(index_columns, ", ")
        )
      else
        columns[#columns + 1] = def
      end
    end
  end

  if #columns == 0 then
    error("connection:create_table: no columns specified", 2)
  end

  -- CREATE TABLE goes first, followed by the index statements in the order
  -- they were given.
  local statements = {
    string.format(
      "CREATE TABLE IF NOT EXISTS %s (%s)",
      table_name,
      table.concat(columns, ", ")
    )
  }
  for _, statement in ipairs(index_statements) do
    statements[#statements + 1] = statement
  end

  return self:exec(table.concat(statements, ";\n"))
end,
-- Deletes rows from `table_name` and returns whatever self:exec returns.
-- `where` is an optional SQL condition; without it the statement has no
-- WHERE clause and therefore targets every row.
-- `params` and any further arguments are handed to self:exec unchanged, so
-- parameters can be a table or a list of positional values, as with exec.
-- `table_name` and `where` are interpolated into the SQL as-is, so they must
-- come from trusted code.
delete = function(self, table_name, where, params, ...)
  local query = "DELETE FROM " .. table_name
  if where then
    query = query .. " WHERE " .. where
  end
  -- Forward the varargs too; previously extra positional values were
  -- silently dropped.
  return self:exec(query, params, ...)
end,
-- Runs `query` via self:query and returns the first row of the result, or
-- nil when there is no result.
-- " LIMIT 1" is appended unless the query already appears to contain a
-- LIMIT clause (numeric or parameterized). The check is a plain text search
-- on the lower-cased query, so a LIMIT inside a subquery also counts.
-- `params` and any further arguments are forwarded to self:query unchanged.
--
-- Raises an error (reported at the caller) if `query` is not a string.
get_one = function(self, query, params, ...)
  if type(query) ~= "string" then
    error("connection:get_one: query must be a string", 2)
  end

  local limited_query = query
  -- Accept "LIMIT 5" as well as "LIMIT ?", "LIMIT :n", "LIMIT @n", "LIMIT $n"
  if not query:lower():find("limit%s+[%d%?:@%$]") then
    -- Drop trailing whitespace and semicolons first, otherwise the LIMIT
    -- would end up after the statement terminator.
    limited_query = query:gsub("[%s;]+$", "") .. " LIMIT 1"
  end

  local results = self:query(limited_query, params, ...)
  if results == nil then
    return nil
  end
  return results[1]
end,
-- Starts a transaction by executing "BEGIN TRANSACTION" through self:exec;
-- returns whatever self:exec returns.
begin = function(self)
  local statement = "BEGIN TRANSACTION"
  return self:exec(statement)
end,
-- Commits the current transaction by executing "COMMIT" through self:exec;
-- returns whatever self:exec returns.
commit = function(self)
  local statement = "COMMIT"
  return self:exec(statement)
end,
-- Rolls the current transaction back by executing "ROLLBACK" through
-- self:exec; returns whatever self:exec returns.
rollback = function(self)
  local statement = "ROLLBACK"
  return self:exec(statement)
end,
-- Runs `callback(self)` between self:begin() and self:commit().
-- On success the transaction is committed and every value returned by the
-- callback is returned to the caller.
-- If the callback raises, a rollback is attempted and the callback's error
-- is re-raised unchanged. An error raised by the rollback itself is
-- discarded so that it cannot hide the original failure.
-- Errors raised by self:begin() or self:commit() propagate as-is.
transaction = function(self, callback)
  self:begin()

  -- Receives pcall's results; the varargs keep all callback return values.
  local function finish(ok, ...)
    if ok then
      self:commit()
      return ...
    end
    pcall(self.rollback, self)
    -- Level 0: do not prepend another source position to the message.
    error((...), 0)
  end

  return finish(pcall(callback, self))
end
}
}
-- Module value: the sqlite() constructor.
-- Returns a connection object for `db_name`: a table holding only that name,
-- with connection_mt as its metatable so the connection methods resolve.
-- Raises an error (reported at the caller) if `db_name` is not a string.
return function(db_name)
  if type(db_name) == "string" then
    return setmetatable({ db_name = db_name }, connection_mt)
  end
  error("sqlite: database name must be a string", 2)
end