-- sqlite.lua
--
-- Convenience wrapper around host-provided SQLite primitives
-- (`__sqlite_query`, `__sqlite_exec`, `__sqlite_get_one`, resolved as globals
-- at call time). `sqlite(db_name)` returns a connection object whose methods
-- build SQL text and forward it, together with bound parameter values, to
-- those primitives.
--
-- Only *values* are bound as parameters. Table names, column names, column /
-- index definitions and WHERE clauses are spliced into the SQL verbatim, so
-- they must never come from untrusted input.

-- Name prefixes for parameters that `update` binds on the caller's behalf.
-- They are deliberately unlike typical user parameter names (`:p1`, `:id`)
-- so caller-supplied named WHERE parameters cannot overwrite SET values.
local SET_PREFIX = "_set_p"
local WHERE_PREFIX = "_where_p"

--- Collapses the `params, ...` calling convention into one bind table.
-- A table in `params` is returned unchanged (named and/or positional binds).
-- Otherwise `params` followed by the varargs becomes a positional array in
-- which `nil` values keep their slot, so later values stay in position.
-- @param params bind table, first positional value, or nil
-- @param ... further positional values
-- @return bind table, or nil when no non-nil value was supplied
local function normalize_params(params, ...)
  if type(params) == "table" then
    return params
  end
  -- Build with a constructor instead of table.insert: `#` is unspecified on
  -- tables with nil holes, so shifting values into one is unreliable.
  local args = { params, ... }
  if next(args) == nil then
    return nil
  end
  return args
end

--- Raises an error attributed to the caller of a connection method unless
-- `value` is a string. Level 3 skips this helper and the method itself.
local function require_string(method, what, value)
  if type(value) ~= "string" then
    error(string.format("connection:%s: %s must be a string", method, what), 3)
  end
end

--- Strips leading and trailing whitespace.
local function trim(s)
  return s:match("^%s*(.-)%s*$")
end

--- Returns the union of the keys of every table in `rows`, sorted by their
-- string form so the generated SQL is deterministic (pairs() order is not).
local function sorted_keys(rows)
  local seen, keys = {}, {}
  for _, row in ipairs(rows) do
    for key in pairs(row) do
      if not seen[key] then
        seen[key] = true
        keys[#keys + 1] = key
      end
    end
  end
  table.sort(keys, function(a, b)
    return tostring(a) < tostring(b)
  end)
  return keys
end

--- Builds a (possibly multi-row) INSERT statement using `:pN` placeholders.
-- @param table_name target table (spliced verbatim)
-- @param columns list of column names
-- @param rows list of rows; one "(...)" value group is emitted per row
-- @param value_of function(row, position, column) returning the bound value
-- @return query string, bind table keyed by placeholder name (without ':')
local function build_insert(table_name, columns, rows, value_of)
  local params, groups, n = {}, {}, 0
  for r, row in ipairs(rows) do
    local placeholders = {}
    for c = 1, #columns do
      n = n + 1
      local name = "p" .. n
      placeholders[c] = ":" .. name
      params[name] = value_of(row, c, columns[c])
    end
    groups[r] = "(" .. table.concat(placeholders, ", ") .. ")"
  end
  -- NOTE(review): SQLite caps the number of host parameters per statement
  -- (SQLITE_MAX_VARIABLE_NUMBER); very large batches may need chunking.
  local query = string.format(
    "INSERT INTO %s (%s) VALUES %s",
    table_name,
    table.concat(columns, ", "),
    table.concat(groups, ", ")
  )
  return query, params
end

-- Value extractors for build_insert: positional rows vs. keyed rows.
local function value_by_position(row, position)
  return row[position]
end

local function value_by_name(row, _, column)
  return row[column]
end

--- Rewrites anonymous `?` and numbered `?NNN` placeholders in `sql` into named
-- `:<prefix>N` placeholders. Text inside '...', "..." and `...` quotes is left
-- untouched. Bare `?` is numbered one past the largest index seen so far,
-- mirroring SQLite's own numbering rule, so positional value N binds to the
-- placeholder SQLite would have given index N.
local function name_positional_placeholders(sql, prefix)
  local out = {}
  local quote = nil
  local max_index = 0
  local i, len = 1, #sql
  while i <= len do
    local ch = sql:sub(i, i)
    if quote then
      -- A doubled quote ('') closes and immediately reopens, which is correct.
      if ch == quote then
        quote = nil
      end
      out[#out + 1] = ch
      i = i + 1
    elseif ch == "'" or ch == '"' or ch == "`" then
      quote = ch
      out[#out + 1] = ch
      i = i + 1
    elseif ch == "?" then
      local digits = sql:match("^%d+", i + 1)
      local index
      if digits then
        index = tonumber(digits)
        i = i + 1 + #digits
      else
        index = max_index + 1
        i = i + 1
      end
      if index > max_index then
        max_index = index
      end
      out[#out + 1] = ":" .. prefix .. index
    else
      out[#out + 1] = ch
      i = i + 1
    end
  end
  return table.concat(out)
end

--- Parses the body of an "INDEX:" / "UNIQUE INDEX:" definition, expected to
-- look like "name (col1, col2)". Returns an index spec, or nil when the text
-- is malformed (such definitions are skipped, as they always were).
local function parse_index(spec, unique)
  local paren = spec:find("(", 1, true)
  if not paren then
    return nil
  end
  local name = trim(spec:sub(1, paren - 1))
  local inner = spec:sub(paren + 1):match("^(.-)%)%s*$")
  if name == "" or not inner then
    return nil
  end
  local columns = {}
  for part in inner:gmatch("[^,]+") do
    local column = trim(part)
    if column ~= "" then
      columns[#columns + 1] = column
    end
  end
  if #columns == 0 then
    return nil
  end
  return { name = name, columns = columns, unique = unique }
end

--- Completes a transaction given the results of `pcall(callback, conn)`.
-- Commits and returns every callback result on success. On failure (of the
-- callback or of COMMIT) it rolls back and re-raises the original error
-- unchanged (level 0). A failing ROLLBACK is ignored so it cannot mask that
-- error.
local function settle_transaction(conn, ok, ...)
  if not ok then
    local err = ...
    pcall(conn.rollback, conn)
    error(err, 0)
  end
  local committed, commit_err = pcall(conn.commit, conn)
  if not committed then
    pcall(conn.rollback, conn)
    error(commit_err, 0)
  end
  return ...
end

local connection_mt = {
  __index = {
    --- Runs a query and returns whatever `__sqlite_query` returns.
    query = function(self, query, params, ...)
      require_string("query", "query", query)
      return __sqlite_query(self.db_name, query, normalize_params(params, ...))
    end,

    --- Executes a statement and returns whatever `__sqlite_exec` returns.
    exec = function(self, query, params, ...)
      require_string("exec", "query", query)
      return __sqlite_exec(self.db_name, query, normalize_params(params, ...))
    end,

    --- Runs a query and returns whatever `__sqlite_get_one` returns.
    get_one = function(self, query, params, ...)
      require_string("get_one", "query", query)
      return __sqlite_get_one(self.db_name, query, normalize_params(params, ...))
    end,

    --- Inserts one or more rows. Accepted shapes of `data`:
    --   {col = val, ...}                 single object
    --   {v1, v2, ...} + columns          single positional row
    --   {{v1, v2}, {v3, v4}} + columns   several positional rows
    --   {{col = val}, {col = val}, ...}  several objects; the column list is
    --                                    the union of all rows' keys and
    --                                    missing values bind as nil
    -- Empty `data` inserts nothing and returns 0 (like `update`).
    insert = function(self, table_name, data, columns)
      require_string("insert", "table_name", table_name)
      if type(data) ~= "table" then
        error("connection:insert: data must be a table", 2)
      end
      if next(data) == nil then
        return 0
      end

      local query, params
      if data[1] == nil then
        query, params = build_insert(table_name, sorted_keys({ data }), { data }, value_by_name)
      elseif type(columns) == "table" then
        if type(data[1]) == "table" then
          query, params = build_insert(table_name, columns, data, value_by_position)
        else
          query, params = build_insert(table_name, columns, { data }, value_by_position)
        end
      elseif type(data[1]) == "table" and data[1][1] == nil then
        query, params = build_insert(table_name, sorted_keys(data), data, value_by_name)
      else
        error("connection:insert: invalid data format", 2)
      end
      return self:exec(query, params)
    end,

    --- Updates `table_name` with the column/value pairs in `data`.
    -- `where` (optional) is spliced verbatim. It may use named parameters
    -- (`:id` with `{id = 1}`) or positional ones (`?` / `?NNN` with a list
    -- or varargs). Positional placeholders are renamed internally so they
    -- bind alongside the SET values. Returns 0 without touching the
    -- database when `data` is empty or not a table.
    update = function(self, table_name, data, where, where_params, ...)
      if type(data) ~= "table" or next(data) == nil then
        return 0
      end
      require_string("update", "table_name", table_name)

      local params, sets = {}, {}
      for i, column in ipairs(sorted_keys({ data })) do
        local name = SET_PREFIX .. i
        sets[i] = column .. " = :" .. name
        params[name] = data[column]
      end

      local query = string.format("UPDATE %s SET %s", table_name, table.concat(sets, ", "))
      if where then
        query = query .. " WHERE " .. name_positional_placeholders(tostring(where), WHERE_PREFIX)
        local binds = normalize_params(where_params, ...)
        if binds then
          for key, value in pairs(binds) do
            local name
            if type(key) == "string" then
              name = key
            elseif type(key) == "number" then
              name = WHERE_PREFIX .. key
            end
            if name then
              if params[name] ~= nil then
                error("connection:update: where parameter '" .. name
                  .. "' conflicts with another bound parameter", 2)
              end
              params[name] = value
            end
          end
        end
      end
      return self:exec(query, params)
    end,

    --- Creates a table (IF NOT EXISTS) plus optional indexes.
    -- Each string argument is either a column/constraint definition or an
    -- index spec "INDEX: name (col, ...)" / "UNIQUE INDEX: name (col, ...)".
    -- Non-string arguments and malformed index specs are skipped.
    -- NOTE(review): all statements go to a single exec call joined by ";\n";
    -- this assumes `__sqlite_exec` runs every statement in the string.
    create_table = function(self, table_name, ...)
      require_string("create_table", "table_name", table_name)
      local definitions = { ... }
      local column_definitions, index_definitions = {}, {}

      for i = 1, select("#", ...) do
        local def = definitions[i]
        if type(def) == "string" then
          local index
          if def:sub(1, 13) == "UNIQUE INDEX:" then
            index = parse_index(trim(def:sub(14)), true)
          elseif def:sub(1, 6) == "INDEX:" then
            index = parse_index(trim(def:sub(7)), false)
          else
            column_definitions[#column_definitions + 1] = def
          end
          if index then
            index_definitions[#index_definitions + 1] = index
          end
        end
      end

      if #column_definitions == 0 then
        error("connection:create_table: no column definitions specified for table " .. table_name, 2)
      end

      local statements = {
        string.format(
          "CREATE TABLE IF NOT EXISTS %s (%s)",
          table_name,
          table.concat(column_definitions, ", ")
        ),
      }
      for _, idx in ipairs(index_definitions) do
        statements[#statements + 1] = string.format(
          "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
          idx.unique and "UNIQUE " or "",
          idx.name,
          table_name,
          table.concat(idx.columns, ", ")
        )
      end
      return self:exec(table.concat(statements, ";\n"))
    end,

    --- Deletes rows matching `where` (all rows when `where` is nil).
    delete = function(self, table_name, where, params, ...)
      require_string("delete", "table_name", table_name)
      local query = "DELETE FROM " .. table_name
      if where then
        query = query .. " WHERE " .. where
      end
      return self:exec(query, normalize_params(params, ...))
    end,

    --- Returns true when at least one row of `table_name` matches `where`.
    exists = function(self, table_name, where, params, ...)
      require_string("exists", "table_name", table_name)
      local query = "SELECT 1 FROM " .. table_name
      if where then
        query = query .. " WHERE " .. where
      end
      query = query .. " LIMIT 1"
      local results, err = self:query(query, normalize_params(params, ...))
      if results == nil then
        error("connection:exists: query failed: " .. tostring(err), 2)
      end
      return #results > 0
    end,

    begin = function(self)
      return self:exec("BEGIN TRANSACTION")
    end,

    commit = function(self)
      return self:exec("COMMIT")
    end,

    rollback = function(self)
      return self:exec("ROLLBACK")
    end,

    --- Runs `callback(self)` inside BEGIN/COMMIT and returns all its results.
    -- Any error (from the callback or from COMMIT) triggers a ROLLBACK and is
    -- re-raised unchanged.
    transaction = function(self, callback)
      self:begin()
      return settle_transaction(self, pcall(callback, self))
    end,
  },
}

--- Opens a connection handle for `db_name`. Deliberately global: this is the
-- module's public entry point.
function sqlite(db_name)
  if type(db_name) ~= "string" then
    error("sqlite: database name must be a string", 2)
  end
  return setmetatable({ db_name = db_name }, connection_mt)
end