-- Moonshark/runner/lua/sqlite.lua (Lua source, 7.7 KiB, 300 lines)

-- sqlite.lua
--- Coerce query parameters into the shape handed to the host bindings.
-- A table argument is returned unchanged (the caller chose positional or
-- named binding). Otherwise `params` and every extra argument are placed, in
-- order, into an array. Returns nil when no non-nil value was supplied.
-- @param params table of parameters, first positional value, or nil
-- @param ... further positional values
-- @return table or nil
local function normalize_params(params, ...)
    if type(params) == "table" then
        return params
    end
    -- Build the array in one constructor so every value keeps its position even
    -- when some arguments are nil. `{...}` followed by table.insert(t, 1, v)
    -- depends on `#t`, which is unspecified for tables with holes.
    local args = { params, ... }
    if next(args) == nil then
        return nil
    end
    return args
end
--- Raise an error attributed to the caller of a connection method unless
-- `value` is a string. Level 3 skips this helper and the method that called it.
local function expect_string(value, message)
    if type(value) ~= "string" then
        error(message, 3)
    end
end

--- Return `s` without leading or trailing whitespace.
local function trim(s)
    return (s:match("^%s*(.-)%s*$"))
end

--- Collect the keys of `t` into an array sorted by their string form, so the
-- same row shape always yields the same column order (and the same SQL text).
local function sorted_keys(t)
    local keys = {}
    for key in pairs(t) do
        keys[#keys + 1] = key
    end
    table.sort(keys, function(a, b)
        return tostring(a) < tostring(b)
    end)
    return keys
end

--- Build an INSERT statement for one or more rows.
-- Placeholders are named :p1, :p2, ... and numbered across all rows.
-- @param table_name target table (interpolated verbatim into the SQL)
-- @param columns array of column names
-- @param rows array of rows; iteration stops at the first nil (ipairs)
-- @param get function(row, index, column) returning the value for a column
-- @return query string, params table keyed "p1", "p2", ...
local function build_insert(table_name, columns, rows, get)
    local groups, params, n = {}, {}, 0
    for r, row in ipairs(rows) do
        local placeholders = {}
        for i = 1, #columns do
            n = n + 1
            placeholders[i] = ":p" .. n
            params["p" .. n] = get(row, i, columns[i])
        end
        groups[r] = "(" .. table.concat(placeholders, ", ") .. ")"
    end
    local query = string.format(
        "INSERT INTO %s (%s) VALUES %s",
        table_name,
        table.concat(columns, ", "),
        table.concat(groups, ", ")
    )
    return query, params
end

-- Value accessors for build_insert: positional rows ({v1, v2}) vs keyed rows
-- ({col = v}).
local function by_position(row, i)
    return row[i]
end

local function by_name(row, _, column)
    return row[column]
end

-- Prefixes that mark a create_table argument as an index definition. UNIQUE is
-- listed first so it wins over the plain form.
local INDEX_PREFIXES = {
    { text = "UNIQUE INDEX:", unique = true },
    { text = "INDEX:", unique = false },
}

--- If `def` starts with an index prefix, return the trimmed remainder and
-- whether the index is unique; otherwise return nil (a column definition).
local function match_index_prefix(def)
    for _, prefix in ipairs(INDEX_PREFIXES) do
        if def:sub(1, #prefix.text) == prefix.text then
            return trim(def:sub(#prefix.text + 1)), prefix.unique
        end
    end
    return nil
end

--- Parse "name(col1, col2)" into {name=, columns=, unique=}.
-- Returns nil for a malformed definition (no "(", no closing ")", empty name
-- or no non-empty columns); create_table skips those.
local function parse_index_body(body, unique)
    local paren_pos = body:find("(", 1, true)
    if not paren_pos then
        return nil
    end
    local name = trim(body:sub(1, paren_pos - 1))
    local columns_part = body:sub(paren_pos + 1):match("^(.-)%)%s*$")
    if name == "" or not columns_part then
        return nil
    end
    local columns = {}
    -- Appending "," lets every segment, including the last, end in a comma.
    for piece in (columns_part .. ","):gmatch("([^,]*),") do
        local column = trim(piece)
        if column ~= "" then
            columns[#columns + 1] = column
        end
    end
    if #columns == 0 then
        return nil
    end
    return { name = name, columns = columns, unique = unique }
end

--- Method table shared by every connection handle returned from sqlite().
-- Each handle carries only `db_name`; the methods forward it to the globals
-- __sqlite_query / __sqlite_exec / __sqlite_get_one (presumably provided by
-- the host runtime; they are not defined in this file).
-- NOTE(review): table, column and index names are interpolated into SQL
-- verbatim; only values are bound as parameters. Never pass untrusted
-- identifiers.
local connection_mt = {
    __index = {
        --- Run `query` and return whatever __sqlite_query returns.
        -- Parameters: a table, or positional values as extra arguments.
        query = function(self, query, params, ...)
            if type(query) ~= "string" then
                error("connection:query: query must be a string", 2)
            end
            return __sqlite_query(self.db_name, query, normalize_params(params, ...))
        end,

        --- Execute a statement and return whatever __sqlite_exec returns.
        exec = function(self, query, params, ...)
            if type(query) ~= "string" then
                error("connection:exec: query must be a string", 2)
            end
            return __sqlite_exec(self.db_name, query, normalize_params(params, ...))
        end,

        --- Run `query` and return whatever __sqlite_get_one returns.
        get_one = function(self, query, params, ...)
            if type(query) ~= "string" then
                error("connection:get_one: query must be a string", 2)
            end
            return __sqlite_get_one(self.db_name, query, normalize_params(params, ...))
        end,

        --- Insert rows into `table_name`. Accepted shapes:
        --   {col = v, ...}                 single object
        --   {v1, v2}, columns              one positional row
        --   {{v1, v2}, {v3, v4}}, columns  several positional rows
        --   {{col = v}, {col = v}}         array of objects (columns from row 1)
        -- Values are bound through named :pN placeholders.
        -- NOTE(review): large multi-row inserts may exceed SQLite's host
        -- parameter limit (SQLITE_MAX_VARIABLE_NUMBER).
        insert = function(self, table_name, data, columns)
            if type(data) ~= "table" then
                error("connection:insert: data must be a table", 2)
            end
            expect_string(table_name, "connection:insert: table_name must be a string")

            local query, params
            if data[1] == nil and next(data) ~= nil then
                query, params = build_insert(table_name, sorted_keys(data), { data }, by_name)
            elseif type(columns) == "table" then
                if #data > 0 and type(data[1]) == "table" then
                    query, params = build_insert(table_name, columns, data, by_position)
                else
                    query, params = build_insert(table_name, columns, { data }, by_position)
                end
            elseif #data > 0 and type(data[1]) == "table" and data[1][1] == nil then
                query, params = build_insert(table_name, sorted_keys(data[1]), data, by_name)
            else
                error("connection:insert: invalid data format", 2)
            end
            return self:exec(query, params)
        end,

        --- Update `table_name` with the column/value pairs in `data`.
        -- Returns 0 without touching the database when `data` is empty.
        -- WHERE parameters may be positional (`?`, passed as extra arguments or
        -- an array) or named (a table with string keys). For positional ones the
        -- SET values are also bound positionally and come first, so the `?`s in
        -- `where` receive the WHERE values in order. Explicitly numbered `?NNN`
        -- placeholders in `where` are not renumbered.
        update = function(self, table_name, data, where, where_params, ...)
            if type(data) ~= "table" or next(data) == nil then
                return 0
            end
            expect_string(table_name, "connection:update: table_name must be a string")

            local where_args = where and normalize_params(where_params, ...) or nil

            -- Copy named WHERE arguments; remember the highest positional index.
            local params = {}
            local positional_count = 0
            if type(where_args) == "table" then
                for key, value in pairs(where_args) do
                    if type(key) == "number" and key >= 1 and key % 1 == 0 then
                        if key > positional_count then
                            positional_count = key
                        end
                    else
                        params[key] = value
                    end
                end
            end

            local columns = sorted_keys(data)
            local sets = {}
            if positional_count > 0 then
                for i, column in ipairs(columns) do
                    sets[i] = column .. " = ?"
                    params[i] = data[column]
                end
                local offset = #columns
                for i = 1, positional_count do
                    params[offset + i] = where_args[i]
                end
            else
                -- Generated names skip any key the caller's WHERE already uses.
                local n = 0
                for i, column in ipairs(columns) do
                    local name
                    repeat
                        n = n + 1
                        name = "p" .. n
                    until params[name] == nil
                    sets[i] = column .. " = :" .. name
                    params[name] = data[column]
                end
            end

            local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", "))
            if where then
                query = query .. " WHERE " .. where
            end
            return self:exec(query, params)
        end,

        --- Create `table_name` (IF NOT EXISTS) plus any requested indexes.
        -- Each extra string argument is a column definition, unless it starts
        -- with "INDEX:" or "UNIQUE INDEX:" followed by "name(col, ...)".
        -- Malformed index definitions and non-string arguments are skipped.
        -- All statements are sent in one exec call, joined by ";\n".
        create_table = function(self, table_name, ...)
            expect_string(table_name, "connection:create_table: table_name must be a string")

            local column_definitions, index_definitions = {}, {}
            for _, def in ipairs({ ... }) do
                if type(def) == "string" then
                    local body, unique = match_index_prefix(def)
                    if body == nil then
                        column_definitions[#column_definitions + 1] = def
                    else
                        local index = parse_index_body(body, unique)
                        if index then
                            index_definitions[#index_definitions + 1] = index
                        end
                    end
                end
            end
            if #column_definitions == 0 then
                error("connection:create_table: no column definitions specified for table " .. table_name, 2)
            end

            local statements = {
                string.format(
                    "CREATE TABLE IF NOT EXISTS %s (%s)",
                    table_name,
                    table.concat(column_definitions, ", ")
                ),
            }
            for _, index in ipairs(index_definitions) do
                statements[#statements + 1] = string.format(
                    "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
                    index.unique and "UNIQUE " or "",
                    index.name,
                    table_name,
                    table.concat(index.columns, ", ")
                )
            end
            return self:exec(table.concat(statements, ";\n"))
        end,

        --- Delete rows matching `where`; without `where`, every row is deleted.
        delete = function(self, table_name, where, params, ...)
            expect_string(table_name, "connection:delete: table_name must be a string")
            local query = "DELETE FROM " .. table_name
            if where then
                query = query .. " WHERE " .. where
            end
            return self:exec(query, normalize_params(params, ...))
        end,

        --- Return true when at least one row of `table_name` matches `where`
        -- (or the table has any row when `where` is nil).
        exists = function(self, table_name, where, params, ...)
            if type(table_name) ~= "string" then
                error("connection:exists: table_name must be a string", 2)
            end
            local query = "SELECT 1 FROM " .. table_name
            if where then
                query = query .. " WHERE " .. where
            end
            query = query .. " LIMIT 1"
            local results = self:query(query, normalize_params(params, ...))
            return #results > 0
        end,

        begin = function(self)
            return self:exec("BEGIN TRANSACTION")
        end,

        commit = function(self)
            return self:exec("COMMIT")
        end,

        rollback = function(self)
            return self:exec("ROLLBACK")
        end,

        --- Run `callback(self)` between BEGIN and COMMIT and return its first
        -- result.
        -- If the callback raises, the transaction is rolled back and the error
        -- is re-raised unchanged (level 0, so no extra position prefix).
        -- If COMMIT fails, a rollback is attempted before re-raising, so the
        -- connection is not left inside an open transaction.
        -- A failing ROLLBACK never masks the original error.
        -- Uses a plain BEGIN; nested calls presumably fail in SQLite.
        transaction = function(self, callback)
            self:begin()
            local ok, result = pcall(callback, self)
            if not ok then
                pcall(self.rollback, self)
                error(result, 0)
            end
            local committed, commit_err = pcall(self.commit, self)
            if not committed then
                pcall(self.rollback, self)
                error(commit_err, 0)
            end
            return result
        end,
    },
}
--- Create a connection handle for the named SQLite database.
-- The handle is a fresh table holding only `db_name`; its methods are looked
-- up through the `connection_mt` metatable.
-- @param db_name string name of the database
-- @return connection handle
-- @raise "sqlite: database name must be a string" (reported at the caller)
function sqlite(db_name)
    local name_type = type(db_name)
    if name_type == "string" then
        local handle = { db_name = db_name }
        return setmetatable(handle, connection_mt)
    end
    error("sqlite: database name must be a string", 2)
end