300 lines
7.7 KiB
Lua
300 lines
7.7 KiB
Lua
-- sqlite.lua
|
|
|
|
--- Normalize the parameter arguments accepted by the query helpers.
-- A table argument is returned unchanged (it may hold named or positional
-- bindings). Otherwise every argument is collected positionally, with nil
-- values kept in their own slots so later parameters are never shifted.
-- @param params first parameter value, or a table of parameters
-- @param ... further positional parameter values
-- @treturn table|nil the parameter table, or nil when every argument is nil
local function normalize_params(params, ...)
  if type(params) == "table" then
    return params
  end

  -- A single constructor places each argument at its own index even when
  -- earlier ones are nil; `#` on a table with holes is unspecified, so any
  -- length-based insertion could move values to the wrong slot.
  local args = { params, ... }
  if next(args) == nil then
    -- No arguments, or only nils: there is nothing to bind.
    return nil
  end
  return args
end
|
|
|
|
--- Raise `message` unless `value` is a string.
-- Level 3 skips this helper and the public method that called it, so the
-- error points at the user's call site.
local function require_string(value, message)
  if type(value) ~= "string" then
    error(message, 3)
  end
end

--- Order keys by their string form (pairs() order is unspecified); gives the
-- generated SQL a deterministic column order and placeholder numbering.
local function by_string(a, b)
  return tostring(a) < tostring(b)
end

--- Return the keys of `t` as an array, sorted with `by_string`.
local function sorted_keys(t)
  local keys = {}
  for key in pairs(t) do
    keys[#keys + 1] = key
  end
  table.sort(keys, by_string)
  return keys
end

--- Make a column name safe to splice into SQL text.
-- Plain identifiers (letters, digits, underscore, not starting with a digit)
-- and names already quoted with "...", [...] or `...` are returned unchanged,
-- so ordinary queries keep their exact text. Anything else is wrapped in
-- double quotes with embedded double quotes doubled, which stops keys taken
-- from caller data from injecting SQL.
local function quote_ident(name)
  local text = tostring(name)
  if text:match("^[%a_][%w_]*$")
    or text:match('^"[^"]*"$')
    or text:match("^%[[^%]]*%]$")
    or text:match("^`[^`]*`$") then
    return text
  end
  return '"' .. (text:gsub('"', '""')) .. '"'
end

--- Row accessor for object rows: {col = value}.
local function value_by_name(row, _, column)
  return row[column]
end

--- Row accessor for positional rows: {value1, value2}.
-- Falls back to a lookup by column name so object rows passed together with
-- an explicit column list are not inserted as NULLs.
local function value_by_position(row, index, column)
  local value = row[index]
  if value == nil then
    value = row[column]
  end
  return value
end

--- Build an INSERT statement with one named placeholder (:p1, :p2, ...) per
-- value, numbered across all rows.
-- @param table_name table name, interpolated verbatim (must be trusted)
-- @param columns array of column names
-- @param rows array of rows (tables)
-- @param get_value function(row, index, column) returning the value to bind
-- @return SQL text and its named-parameter table
local function build_insert(table_name, columns, rows, get_value)
  local column_sql = {}
  for i = 1, #columns do
    column_sql[i] = quote_ident(columns[i])
  end

  local groups, params, n = {}, {}, 0
  for r, row in ipairs(rows) do
    local holders = {}
    for i = 1, #columns do
      n = n + 1
      local name = "p" .. n
      holders[i] = ":" .. name
      params[name] = get_value(row, i, columns[i])
    end
    groups[r] = "(" .. table.concat(holders, ", ") .. ")"
  end

  local sql = string.format(
    "INSERT INTO %s (%s) VALUES %s",
    table_name,
    table.concat(column_sql, ", "),
    table.concat(groups, ", ")
  )
  return sql, params
end

--- Rewrite positional placeholders in a WHERE clause into named ones.
-- Each `?` or `?NNN` outside a quoted span ('...', "...", `...`, [...]) is
-- replaced by `:where_N`, and `params.where_N` is set to `values[N]`. A bare
-- `?` takes the index one above the largest index used so far, mirroring
-- SQLite's numbering. This lets the WHERE values share one named-parameter
-- table with the SET values.
-- @param clause WHERE clause text
-- @param values positional values (integer keys)
-- @param params named-parameter table to extend (mutated)
-- @treturn string the rewritten clause
local function name_positional_placeholders(clause, values, params)
  local out = {}
  local closing -- closing character of the quoted span we are inside, if any
  local highest = 0
  local i, len = 1, #clause
  while i <= len do
    local ch = clause:sub(i, i)
    if closing then
      -- A doubled quote ('') closes and immediately reopens, which is fine.
      if ch == closing then
        closing = nil
      end
      out[#out + 1] = ch
      i = i + 1
    elseif ch == "'" or ch == '"' or ch == "`" or ch == "[" then
      closing = (ch == "[") and "]" or ch
      out[#out + 1] = ch
      i = i + 1
    elseif ch == "?" then
      local digits = clause:match("^%d+", i + 1)
      local index
      if digits then
        index = tonumber(digits)
        i = i + 1 + #digits
      else
        index = highest + 1
        i = i + 1
      end
      if index > highest then
        highest = index
      end
      local name = "where_" .. index
      params[name] = values[index]
      out[#out + 1] = ":" .. name
    else
      out[#out + 1] = ch
      i = i + 1
    end
  end
  return table.concat(out)
end

--- Finish a transaction started by connection:transaction.
-- On success: commit and return every value the callback returned. If COMMIT
-- itself fails, roll back (so the connection is not left inside an open
-- transaction) and re-raise the commit error.
-- On failure: roll back, never letting a ROLLBACK error mask the original
-- one, and re-raise the callback's error unchanged (level 0 adds no second
-- position prefix).
local function finish_transaction(conn, ok, ...)
  if ok then
    local committed, commit_err = pcall(conn.commit, conn)
    if not committed then
      pcall(conn.rollback, conn)
      error(commit_err, 0)
    end
    return ...
  end
  local err = ...
  pcall(conn.rollback, conn)
  error(err, 0)
end

--- Metatable for connection objects created by sqlite(db_name).
-- Every method reads `self.db_name` and goes through the global
-- __sqlite_query / __sqlite_exec / __sqlite_get_one functions, which are
-- defined outside this file. Table names are interpolated verbatim and must
-- come from trusted code. Column names are passed through quote_ident.
local connection_mt = {
  __index = {
    --- Run `query` through __sqlite_query.
    -- Parameters are a table, or positional values given as varargs.
    query = function(self, query, params, ...)
      require_string(query, "connection:query: query must be a string")
      return __sqlite_query(self.db_name, query, normalize_params(params, ...))
    end,

    --- Run `query` through __sqlite_exec.
    -- Parameters are handled the same way as in `query`.
    exec = function(self, query, params, ...)
      require_string(query, "connection:exec: query must be a string")
      return __sqlite_exec(self.db_name, query, normalize_params(params, ...))
    end,

    --- Run `query` through __sqlite_get_one.
    -- Parameters are handled the same way as in `query`.
    get_one = function(self, query, params, ...)
      require_string(query, "connection:get_one: query must be a string")
      return __sqlite_get_one(self.db_name, query, normalize_params(params, ...))
    end,

    --- Insert rows into `table_name`. Accepted shapes:
    --   {col = val, ...}                          single object
    --   {val1, val2}, columns                     single positional row
    --   {{val1, val2}, ...}, columns              positional rows
    --   {{col = val}, {col = val, other = v}}     array of objects (columns
    --                                             are the union of all keys)
    -- Returns whatever exec returns; an empty batch with explicit columns is
    -- a no-op that returns 0.
    -- NOTE: very large batches may exceed SQLite's host-parameter limit.
    insert = function(self, table_name, data, columns)
      require_string(table_name, "connection:insert: table_name must be a string")
      if type(data) ~= "table" then
        error("connection:insert: data must be a table", 2)
      end

      -- Single object: {col1=val1, col2=val2}
      if data[1] == nil and next(data) ~= nil then
        local sql, params = build_insert(table_name, sorted_keys(data), { data }, value_by_name)
        return self:exec(sql, params)
      end

      -- Positional data described by an explicit column list
      if type(columns) == "table" then
        if next(data) == nil then
          return 0 -- empty batch: nothing to insert
        end
        local rows = data
        if type(data[1]) == "table" then
          for _, row in ipairs(data) do
            if type(row) ~= "table" then
              error("connection:insert: every row must be a table", 2)
            end
          end
        else
          rows = { data } -- a single positional row
        end
        local sql, params = build_insert(table_name, columns, rows, value_by_position)
        return self:exec(sql, params)
      end

      -- Array of objects
      if type(data[1]) == "table" and data[1][1] == nil then
        local seen, cols = {}, {}
        for _, row in ipairs(data) do
          if type(row) ~= "table" then
            error("connection:insert: every row must be a table", 2)
          end
          for key in pairs(row) do
            if not seen[key] then
              seen[key] = true
              cols[#cols + 1] = key
            end
          end
        end
        table.sort(cols, by_string)
        local sql, params = build_insert(table_name, cols, data, value_by_name)
        return self:exec(sql, params)
      end

      error("connection:insert: invalid data format", 2)
    end,

    --- Update rows of `table_name`, setting each key of `data` to its value.
    -- `where` (optional) is appended verbatim. Its parameters may be a named
    -- table ({id = 5} for ":id") or positional values (a table or varargs)
    -- for `?` / `?NNN` placeholders. Returns 0 without touching the database
    -- when `data` is not a table or is empty.
    update = function(self, table_name, data, where, where_params, ...)
      require_string(table_name, "connection:update: table_name must be a string")
      if type(data) ~= "table" or next(data) == nil then
        return 0
      end

      local params, assignments = {}, {}
      for i, column in ipairs(sorted_keys(data)) do
        local name = "set_" .. i
        assignments[i] = quote_ident(column) .. " = :" .. name
        params[name] = data[column]
      end
      local sql = string.format("UPDATE %s SET %s", table_name, table.concat(assignments, ", "))

      if where then
        local kind = type(where)
        if kind ~= "string" and kind ~= "number" then
          error("connection:update: where must be a string", 2)
        end
        local bound = normalize_params(where_params, ...)
        local positional = {}
        if type(bound) == "table" then
          for key, value in pairs(bound) do
            if type(key) == "string" then
              params[key] = value -- named WHERE parameter
            else
              positional[key] = value
            end
          end
        end
        sql = sql .. " WHERE " .. name_positional_placeholders(tostring(where), positional, params)
      end

      return self:exec(sql, params)
    end,

    --- Create `table_name` (IF NOT EXISTS) from column definition strings.
    -- Strings starting with "INDEX:" or "UNIQUE INDEX:" (case-insensitive,
    -- leading whitespace ignored) in the form "name(col1, col2)" become
    -- CREATE [UNIQUE] INDEX IF NOT EXISTS statements. Index specs that cannot
    -- be parsed are skipped. Non-string arguments are ignored. Raises when no
    -- column definitions remain.
    create_table = function(self, table_name, ...)
      require_string(table_name, "connection:create_table: table_name must be a string")
      local column_definitions = {}
      local index_definitions = {}

      for _, spec in ipairs({ ... }) do
        if type(spec) == "string" then
          local head = spec:match("^%s*(.*)$")
          local upper = head:upper()
          local unique, body
          if upper:sub(1, 13) == "UNIQUE INDEX:" then
            unique, body = true, head:sub(14)
          elseif upper:sub(1, 6) == "INDEX:" then
            unique, body = false, head:sub(7)
          end

          if body == nil then
            column_definitions[#column_definitions + 1] = spec
          else
            -- name: text before the first "("; columns: text up to the final ")"
            local index_name, columns_part = body:match("^%s*(.-)%s*%((.-)%)%s*$")
            if index_name then
              local columns = {}
              for piece in columns_part:gmatch("[^,]+") do
                local column = piece:match("^%s*(.-)%s*$")
                if column ~= "" then
                  columns[#columns + 1] = column
                end
              end
              if #columns > 0 then
                index_definitions[#index_definitions + 1] = {
                  name = index_name,
                  columns = columns,
                  unique = unique,
                }
              end
            end
          end
        end
      end

      if #column_definitions == 0 then
        error("connection:create_table: no column definitions specified for table " .. table_name, 2)
      end

      local statements = {
        string.format(
          "CREATE TABLE IF NOT EXISTS %s (%s)",
          table_name,
          table.concat(column_definitions, ", ")
        ),
      }
      for _, idx in ipairs(index_definitions) do
        statements[#statements + 1] = string.format(
          "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
          idx.unique and "UNIQUE " or "",
          idx.name,
          table_name,
          table.concat(idx.columns, ", ")
        )
      end

      return self:exec(table.concat(statements, ";\n"))
    end,

    --- Delete rows from `table_name`.
    -- `where` is appended verbatim; omitting it deletes every row.
    delete = function(self, table_name, where, params, ...)
      require_string(table_name, "connection:delete: table_name must be a string")
      local sql = "DELETE FROM " .. table_name
      if where then
        sql = sql .. " WHERE " .. where
      end
      return self:exec(sql, params, ...)
    end,

    --- Return true when at least one row of `table_name` matches `where`.
    exists = function(self, table_name, where, params, ...)
      require_string(table_name, "connection:exists: table_name must be a string")
      local sql = "SELECT 1 FROM " .. table_name
      if where then
        sql = sql .. " WHERE " .. where
      end
      sql = sql .. " LIMIT 1"

      local results = self:query(sql, params, ...)
      return results[1] ~= nil
    end,

    --- Execute "BEGIN TRANSACTION".
    begin = function(self)
      return self:exec("BEGIN TRANSACTION")
    end,

    --- Execute "COMMIT".
    commit = function(self)
      return self:exec("COMMIT")
    end,

    --- Execute "ROLLBACK".
    rollback = function(self)
      return self:exec("ROLLBACK")
    end,

    --- Run `callback(self)` inside BEGIN/COMMIT.
    -- Rolls back and re-raises the callback's error if it fails. Returns
    -- every value the callback returned.
    transaction = function(self, callback)
      self:begin()
      return finish_transaction(self, pcall(callback, self))
    end,
  },
}
|
|
|
|
--- Create a connection handle bound to a named database.
-- The handle only stores `db_name`. Its methods (from connection_mt) pass
-- that name to the global __sqlite_* functions, which are not defined in this
-- file, so no database work happens here.
-- @tparam string db_name name of the database to operate on
-- @return connection object whose metatable is connection_mt
-- @raise error when db_name is not a string
function sqlite(db_name)
  local name_type = type(db_name)
  if name_type == "string" then
    local handle = { db_name = db_name }
    return setmetatable(handle, connection_mt)
  end
  error("sqlite: database name must be a string", 2)
end
|