enhance database modules with table utils

This commit is contained in:
Sky Johnson 2025-07-24 10:21:07 -05:00
parent 5551f16bc1
commit 71633b4b4b
3 changed files with 195 additions and 349 deletions

View File

@ -1,4 +1,5 @@
local str = require("string") local str = require("string")
local tbl = require("table")
local mysql = {} local mysql = {}
local Connection = {} local Connection = {}
@ -132,26 +133,20 @@ function Connection:begin()
return nil return nil
end end
-- Enhanced MySQL-specific query builder helpers -- Simplified MySQL-specific query builder helpers
function Connection:insert(table_name, data) function Connection:insert(table_name, data)
if str.is_blank(table_name) then if str.is_blank(table_name) then
error("Table name cannot be empty") error("Table name cannot be empty")
end end
local keys = {} local keys = tbl.keys(data)
local values = {} local values = tbl.values(data)
local placeholders = {} local placeholders = tbl.map(keys, function() return "?" end)
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
end
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", { local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
table = table_name, table = table_name,
columns = str.join(keys, ", "), columns = tbl.concat(keys, ", "),
placeholders = str.join(placeholders, ", ") placeholders = tbl.concat(placeholders, ", ")
}) })
return self:exec(query, unpack(values)) return self:exec(query, unpack(values))
@ -162,28 +157,21 @@ function Connection:upsert(table_name, data, update_data)
error("Table name cannot be empty") error("Table name cannot be empty")
end end
local keys = {} local keys = tbl.keys(data)
local values = {} local values = tbl.values(data)
local placeholders = {} local placeholders = tbl.map(keys, function() return "?" end)
local updates = {}
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
end
-- Use update_data if provided, otherwise update with same data -- Use update_data if provided, otherwise update with same data
local update_source = update_data or data local update_source = update_data or data
for key, _ in pairs(update_source) do local updates = tbl.map(tbl.keys(update_source), function(key)
table.insert(updates, str.template("${key} = VALUES(${key})", {key = key})) return str.template("${key} = VALUES(${key})", {key = key})
end end)
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON DUPLICATE KEY UPDATE ${updates}", { local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON DUPLICATE KEY UPDATE ${updates}", {
table = table_name, table = table_name,
columns = str.join(keys, ", "), columns = tbl.concat(keys, ", "),
placeholders = str.join(placeholders, ", "), placeholders = tbl.concat(placeholders, ", "),
updates = str.join(updates, ", ") updates = tbl.concat(updates, ", ")
}) })
return self:exec(query, unpack(values)) return self:exec(query, unpack(values))
@ -194,20 +182,14 @@ function Connection:replace(table_name, data)
error("Table name cannot be empty") error("Table name cannot be empty")
end end
local keys = {} local keys = tbl.keys(data)
local values = {} local values = tbl.values(data)
local placeholders = {} local placeholders = tbl.map(keys, function() return "?" end)
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
end
local query = str.template("REPLACE INTO ${table} (${columns}) VALUES (${placeholders})", { local query = str.template("REPLACE INTO ${table} (${columns}) VALUES (${placeholders})", {
table = table_name, table = table_name,
columns = str.join(keys, ", "), columns = tbl.concat(keys, ", "),
placeholders = str.join(placeholders, ", ") placeholders = tbl.concat(placeholders, ", ")
}) })
return self:exec(query, unpack(values)) return self:exec(query, unpack(values))
@ -221,25 +203,21 @@ function Connection:update(table_name, data, where_clause, ...)
error("WHERE clause cannot be empty for UPDATE") error("WHERE clause cannot be empty for UPDATE")
end end
local sets = {} local keys = tbl.keys(data)
local values = {} local values = tbl.values(data)
local sets = tbl.map(keys, function(key)
for key, value in pairs(data) do return str.template("${key} = ?", {key = key})
table.insert(sets, str.template("${key} = ?", {key = key})) end)
table.insert(values, value)
end
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", { local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
table = table_name, table = table_name,
sets = str.join(sets, ", "), sets = tbl.concat(sets, ", "),
where = where_clause where = where_clause
}) })
-- Add WHERE clause parameters -- Add WHERE clause parameters
local where_args = {...} local where_args = {...}
for i = 1, #where_args do tbl.extend(values, where_args)
table.insert(values, where_args[i])
end
return self:exec(query, unpack(values)) return self:exec(query, unpack(values))
end end
@ -266,7 +244,7 @@ function Connection:select(table_name, columns, where_clause, ...)
columns = columns or "*" columns = columns or "*"
if type(columns) == "table" then if type(columns) == "table" then
columns = str.join(columns, ", ") columns = tbl.concat(columns, ", ")
end end
local query local query
@ -286,7 +264,7 @@ function Connection:select(table_name, columns, where_clause, ...)
end end
end end
-- Enhanced MySQL schema helpers -- MySQL schema helpers
function Connection:database_exists(database_name) function Connection:database_exists(database_name)
if str.is_blank(database_name) then if str.is_blank(database_name) then
return false return false
@ -441,7 +419,7 @@ function Connection:create_index(index_name, table_name, columns, unique, type)
local unique_clause = unique and "UNIQUE " or "" local unique_clause = unique and "UNIQUE " or ""
local type_clause = type and str.template(" USING ${type}", {type = str.upper(type)}) or "" local type_clause = type and str.template(" USING ${type}", {type = str.upper(type)}) or ""
local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns) local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
local query = str.template("CREATE ${unique}INDEX ${index} ON ${table} (${columns})${type}", { local query = str.template("CREATE ${unique}INDEX ${index} ON ${table} (${columns})${type}", {
unique = unique_clause, unique = unique_clause,
@ -465,7 +443,7 @@ function Connection:drop_index(index_name, table_name)
return self:exec(query) return self:exec(query)
end end
-- Enhanced MySQL maintenance functions -- MySQL maintenance functions
function Connection:optimize(table_name) function Connection:optimize(table_name)
local table_clause = table_name and str.template(" ${table}", {table = table_name}) or "" local table_clause = table_name and str.template(" ${table}", {table = table_name}) or ""
return self:query(str.template("OPTIMIZE TABLE${table}", {table = table_clause})) return self:query(str.template("OPTIMIZE TABLE${table}", {table = table_clause}))
@ -488,15 +466,7 @@ function Connection:check_table(table_name, options)
local valid_options = {"QUICK", "FAST", "MEDIUM", "EXTENDED", "CHANGED"} local valid_options = {"QUICK", "FAST", "MEDIUM", "EXTENDED", "CHANGED"}
local options_upper = str.upper(options) local options_upper = str.upper(options)
local valid = false if tbl.contains(valid_options, options_upper) then
for _, valid_option in ipairs(valid_options) do
if options_upper == valid_option then
valid = true
break
end
end
if valid then
options_clause = str.template(" ${options}", {options = options_upper}) options_clause = str.template(" ${options}", {options = options_upper})
end end
end end
@ -514,7 +484,7 @@ function Connection:analyze_table(table_name)
return self:query(str.template("ANALYZE TABLE ${table}", {table = table_name})) return self:query(str.template("ANALYZE TABLE ${table}", {table = table_name}))
end end
-- Enhanced MySQL settings and introspection -- MySQL settings and introspection
function Connection:show(what) function Connection:show(what)
if str.is_blank(what) then if str.is_blank(what) then
error("SHOW parameter cannot be empty") error("SHOW parameter cannot be empty")
@ -575,7 +545,7 @@ function Connection:show_table_status(table_name)
end end
end end
-- Enhanced MySQL user and privilege management -- MySQL user and privilege management
function Connection:create_user(username, password, host) function Connection:create_user(username, password, host)
if str.is_blank(username) or str.is_blank(password) then if str.is_blank(username) or str.is_blank(password) then
error("Username and password cannot be empty") error("Username and password cannot be empty")
@ -642,7 +612,7 @@ function Connection:flush_privileges()
return self:exec("FLUSH PRIVILEGES") return self:exec("FLUSH PRIVILEGES")
end end
-- Enhanced MySQL variables and configuration -- MySQL variables and configuration
function Connection:set_variable(name, value, global) function Connection:set_variable(name, value, global)
if str.is_blank(name) then if str.is_blank(name) then
error("Variable name cannot be empty") error("Variable name cannot be empty")
@ -683,7 +653,7 @@ function Connection:show_status(pattern)
end end
end end
-- Enhanced connection management -- Connection management
function mysql.connect(dsn) function mysql.connect(dsn)
if str.is_blank(dsn) then if str.is_blank(dsn) then
error("DSN cannot be empty") error("DSN cannot be empty")
@ -700,7 +670,7 @@ end
mysql.open = mysql.connect mysql.open = mysql.connect
-- Enhanced quick execution functions -- Quick execution functions
function mysql.query(dsn, query_str, ...) function mysql.query(dsn, query_str, ...)
local conn = mysql.connect(dsn) local conn = mysql.connect(dsn)
if not conn then if not conn then
@ -741,7 +711,7 @@ function mysql.query_value(dsn, query_str, ...)
return nil return nil
end end
-- Enhanced migration helpers -- Migration helpers
function mysql.migrate(dsn, migrations, database_name) function mysql.migrate(dsn, migrations, database_name)
local conn = mysql.connect(dsn) local conn = mysql.connect(dsn)
if not conn then if not conn then
@ -813,9 +783,9 @@ function mysql.migrate(dsn, migrations, database_name)
return true return true
end end
-- Result processing utilities (same enhanced versions) -- Simplified result processing utilities
function mysql.to_array(results, column_name) function mysql.to_array(results, column_name)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
return {} return {}
end end
@ -823,15 +793,11 @@ function mysql.to_array(results, column_name)
error("Column name cannot be empty") error("Column name cannot be empty")
end end
local array = {} return tbl.map(results, function(row) return row[column_name] end)
for i, row in ipairs(results) do
array[i] = row[column_name]
end
return array
end end
function mysql.to_map(results, key_column, value_column) function mysql.to_map(results, key_column, value_column)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
return {} return {}
end end
@ -848,7 +814,7 @@ function mysql.to_map(results, key_column, value_column)
end end
function mysql.group_by(results, column_name) function mysql.group_by(results, column_name)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
return {} return {}
end end
@ -856,32 +822,20 @@ function mysql.group_by(results, column_name)
error("Column name cannot be empty") error("Column name cannot be empty")
end end
local groups = {} return tbl.group_by(results, function(row) return row[column_name] end)
for _, row in ipairs(results) do
local key = row[column_name]
if not groups[key] then
groups[key] = {}
end
table.insert(groups[key], row)
end
return groups
end end
-- Enhanced debug helper (same as others) -- Simplified debug helper
function mysql.print_results(results) function mysql.print_results(results)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
print("No results") print("No results")
return return
end end
-- Get column names from first row local columns = tbl.keys(results[1])
local columns = {} tbl.sort(columns)
for col, _ in pairs(results[1]) do
table.insert(columns, col)
end
table.sort(columns)
-- Calculate column widths for better formatting -- Calculate column widths
local widths = {} local widths = {}
for _, col in ipairs(columns) do for _, col in ipairs(columns) do
widths[col] = str.length(col) widths[col] = str.length(col)
@ -894,34 +848,28 @@ function mysql.print_results(results)
end end
end end
-- Print header with proper spacing -- Print header
local header_parts = {} local header_parts = tbl.map(columns, function(col) return str.pad_right(col, widths[col]) end)
local separator_parts = {} local separator_parts = tbl.map(columns, function(col) return str.repeat_("-", widths[col]) end)
for _, col in ipairs(columns) do
table.insert(header_parts, str.pad_right(col, widths[col]))
table.insert(separator_parts, str.repeat_("-", widths[col]))
end
print(str.join(header_parts, " | ")) print(tbl.concat(header_parts, " | "))
print(str.join(separator_parts, "-+-")) print(tbl.concat(separator_parts, "-+-"))
-- Print rows with proper spacing -- Print rows
for _, row in ipairs(results) do for _, row in ipairs(results) do
local value_parts = {} local value_parts = tbl.map(columns, function(col)
for _, col in ipairs(columns) do
local value = tostring(row[col] or "") local value = tostring(row[col] or "")
table.insert(value_parts, str.pad_right(value, widths[col])) return str.pad_right(value, widths[col])
end end)
print(str.join(value_parts, " | ")) print(tbl.concat(value_parts, " | "))
end end
end end
-- Enhanced MySQL-specific utilities -- MySQL-specific utilities
function mysql.escape_string(str_val) function mysql.escape_string(str_val)
if type(str_val) ~= "string" then if type(str_val) ~= "string" then
return tostring(str_val) return tostring(str_val)
end end
-- Basic escaping - in production, use proper escaping functions
return str.replace(str_val, "'", "\\'") return str.replace(str_val, "'", "\\'")
end end
@ -932,7 +880,7 @@ function mysql.escape_identifier(name)
return str.template("`${name}`", {name = str.replace(name, "`", "``")}) return str.template("`${name}`", {name = str.replace(name, "`", "``")})
end end
-- Enhanced DSN builder helper -- DSN builder helper
function mysql.build_dsn(options) function mysql.build_dsn(options)
if type(options) ~= "table" then if type(options) ~= "table" then
error("Options must be a table") error("Options must be a table")
@ -941,7 +889,7 @@ function mysql.build_dsn(options)
local parts = {} local parts = {}
if options.username and not str.is_blank(options.username) then if options.username and not str.is_blank(options.username) then
table.insert(parts, options.username) tbl.insert(parts, options.username)
if options.password and not str.is_blank(options.password) then if options.password and not str.is_blank(options.password) then
parts[#parts] = str.template("${user}:${pass}", { parts[#parts] = str.template("${user}:${pass}", {
user = parts[#parts], user = parts[#parts],
@ -952,9 +900,9 @@ function mysql.build_dsn(options)
end end
if options.protocol and not str.is_blank(options.protocol) then if options.protocol and not str.is_blank(options.protocol) then
table.insert(parts, str.template("${protocol}(", {protocol = options.protocol})) tbl.insert(parts, str.template("${protocol}(", {protocol = options.protocol}))
if options.host and not str.is_blank(options.host) then if options.host and not str.is_blank(options.host) then
table.insert(parts, options.host) tbl.insert(parts, options.host)
if options.port then if options.port then
parts[#parts] = str.template("${host}:${port}", { parts[#parts] = str.template("${host}:${port}", {
host = parts[#parts], host = parts[#parts],
@ -971,33 +919,33 @@ function mysql.build_dsn(options)
port = tostring(options.port) port = tostring(options.port)
}) })
end end
table.insert(parts, host_part .. ")") tbl.insert(parts, host_part .. ")")
end end
if options.database and not str.is_blank(options.database) then if options.database and not str.is_blank(options.database) then
table.insert(parts, str.template("/${database}", {database = options.database})) tbl.insert(parts, str.template("/${database}", {database = options.database}))
end end
-- Add parameters -- Add parameters
local params = {} local params = {}
if options.charset and not str.is_blank(options.charset) then if options.charset and not str.is_blank(options.charset) then
table.insert(params, str.template("charset=${charset}", {charset = options.charset})) tbl.insert(params, str.template("charset=${charset}", {charset = options.charset}))
end end
if options.parseTime ~= nil then if options.parseTime ~= nil then
table.insert(params, str.template("parseTime=${parse}", {parse = tostring(options.parseTime)})) tbl.insert(params, str.template("parseTime=${parse}", {parse = tostring(options.parseTime)}))
end end
if options.timeout and not str.is_blank(options.timeout) then if options.timeout and not str.is_blank(options.timeout) then
table.insert(params, str.template("timeout=${timeout}", {timeout = options.timeout})) tbl.insert(params, str.template("timeout=${timeout}", {timeout = options.timeout}))
end end
if options.tls and not str.is_blank(options.tls) then if options.tls and not str.is_blank(options.tls) then
table.insert(params, str.template("tls=${tls}", {tls = options.tls})) tbl.insert(params, str.template("tls=${tls}", {tls = options.tls}))
end end
if #params > 0 then if #params > 0 then
table.insert(parts, str.template("?${params}", {params = str.join(params, "&")})) tbl.insert(parts, str.template("?${params}", {params = tbl.concat(params, "&")}))
end end
return str.join(parts, "") return tbl.concat(parts, "")
end end
return mysql return mysql

View File

@ -1,4 +1,5 @@
local str = require("string") local str = require("string")
local tbl = require("table")
local postgres = {} local postgres = {}
local Connection = {} local Connection = {}
@ -132,23 +133,20 @@ function Connection:begin()
return nil return nil
end end
-- Enhanced query builder helpers with PostgreSQL parameter numbering -- Simplified PostgreSQL parameter builder
local function build_postgres_params(data) local function build_postgres_params(data)
local keys = {} local keys = tbl.keys(data)
local values = {} local values = tbl.values(data)
local placeholders = {} local placeholders = {}
local param_count = 0
for key, value in pairs(data) do for i = 1, #keys do
table.insert(keys, key) tbl.insert(placeholders, str.template("$${num}", {num = tostring(i)}))
table.insert(values, value)
param_count = param_count + 1
table.insert(placeholders, str.template("$${num}", {num = tostring(param_count)}))
end end
return keys, values, placeholders, param_count return keys, values, placeholders, #keys
end end
-- Simplified query builders using table utilities
function Connection:insert(table_name, data, returning) function Connection:insert(table_name, data, returning)
if str.is_blank(table_name) then if str.is_blank(table_name) then
error("Table name cannot be empty") error("Table name cannot be empty")
@ -158,8 +156,8 @@ function Connection:insert(table_name, data, returning)
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", { local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
table = table_name, table = table_name,
columns = str.join(keys, ", "), columns = tbl.concat(keys, ", "),
placeholders = str.join(placeholders, ", ") placeholders = tbl.concat(placeholders, ", ")
}) })
if returning and not str.is_blank(returning) then if returning and not str.is_blank(returning) then
@ -179,27 +177,25 @@ function Connection:upsert(table_name, data, conflict_columns, returning)
end end
local keys, values, placeholders = build_postgres_params(data) local keys, values, placeholders = build_postgres_params(data)
local updates = {} local updates = tbl.map(keys, function(key)
return str.template("${key} = EXCLUDED.${key}", {key = key})
for _, key in ipairs(keys) do end)
table.insert(updates, str.template("${key} = EXCLUDED.${key}", {key = key}))
end
local conflict_clause = "" local conflict_clause = ""
if conflict_columns then if conflict_columns then
if type(conflict_columns) == "string" then if type(conflict_columns) == "string" then
conflict_clause = str.template("(${columns})", {columns = conflict_columns}) conflict_clause = str.template("(${columns})", {columns = conflict_columns})
else else
conflict_clause = str.template("(${columns})", {columns = str.join(conflict_columns, ", ")}) conflict_clause = str.template("(${columns})", {columns = tbl.concat(conflict_columns, ", ")})
end end
end end
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", { local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
table = table_name, table = table_name,
columns = str.join(keys, ", "), columns = tbl.concat(keys, ", "),
placeholders = str.join(placeholders, ", "), placeholders = tbl.concat(placeholders, ", "),
conflict = conflict_clause, conflict = conflict_clause,
updates = str.join(updates, ", ") updates = tbl.concat(updates, ", ")
}) })
if returning and not str.is_blank(returning) then if returning and not str.is_blank(returning) then
@ -221,17 +217,16 @@ function Connection:update(table_name, data, where_clause, returning, ...)
error("WHERE clause cannot be empty for UPDATE") error("WHERE clause cannot be empty for UPDATE")
end end
local sets = {} local keys = tbl.keys(data)
local values = {} local values = tbl.values(data)
local param_count = 0 local param_count = #keys
for key, value in pairs(data) do local sets = {}
param_count = param_count + 1 for i, key in ipairs(keys) do
table.insert(sets, str.template("${key} = $${num}", { tbl.insert(sets, str.template("${key} = $${num}", {
key = key, key = key,
num = tostring(param_count) num = tostring(i)
})) }))
table.insert(values, value)
end end
-- Handle WHERE clause parameters -- Handle WHERE clause parameters
@ -239,15 +234,14 @@ function Connection:update(table_name, data, where_clause, returning, ...)
local where_clause_with_params = where_clause local where_clause_with_params = where_clause
for i = 1, #where_args do for i = 1, #where_args do
param_count = param_count + 1 param_count = param_count + 1
table.insert(values, where_args[i]) tbl.insert(values, where_args[i])
-- Replace ? with numbered parameter if needed
where_clause_with_params = str.replace(where_clause_with_params, "?", where_clause_with_params = str.replace(where_clause_with_params, "?",
str.template("$${num}", {num = tostring(param_count)}), 1) str.template("$${num}", {num = tostring(param_count)}), 1)
end end
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", { local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
table = table_name, table = table_name,
sets = str.join(sets, ", "), sets = tbl.concat(sets, ", "),
where = where_clause_with_params where = where_clause_with_params
}) })
@ -275,7 +269,7 @@ function Connection:delete(table_name, where_clause, returning, ...)
local values = {} local values = {}
local where_clause_with_params = where_clause local where_clause_with_params = where_clause
for i = 1, #where_args do for i = 1, #where_args do
table.insert(values, where_args[i]) tbl.insert(values, where_args[i])
where_clause_with_params = str.replace(where_clause_with_params, "?", where_clause_with_params = str.replace(where_clause_with_params, "?",
str.template("$${num}", {num = tostring(i)}), 1) str.template("$${num}", {num = tostring(i)}), 1)
end end
@ -303,7 +297,7 @@ function Connection:select(table_name, columns, where_clause, ...)
columns = columns or "*" columns = columns or "*"
if type(columns) == "table" then if type(columns) == "table" then
columns = str.join(columns, ", ") columns = tbl.concat(columns, ", ")
end end
local query local query
@ -313,7 +307,7 @@ function Connection:select(table_name, columns, where_clause, ...)
local values = {} local values = {}
local where_clause_with_params = where_clause local where_clause_with_params = where_clause
for i = 1, #where_args do for i = 1, #where_args do
table.insert(values, where_args[i]) tbl.insert(values, where_args[i])
where_clause_with_params = str.replace(where_clause_with_params, "?", where_clause_with_params = str.replace(where_clause_with_params, "?",
str.template("$${num}", {num = tostring(i)}), 1) str.template("$${num}", {num = tostring(i)}), 1)
end end
@ -418,7 +412,7 @@ function Connection:create_index(index_name, table_name, columns, unique, method
local unique_clause = unique and "UNIQUE " or "" local unique_clause = unique and "UNIQUE " or ""
local method_clause = method and str.template(" USING ${method}", {method = str.upper(method)}) or "" local method_clause = method and str.template(" USING ${method}", {method = str.upper(method)}) or ""
local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns) local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table}${method} (${columns})", { local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table}${method} (${columns})", {
unique = unique_clause, unique = unique_clause,
@ -443,7 +437,7 @@ function Connection:drop_index(index_name, cascade)
return self:exec(query) return self:exec(query)
end end
-- Enhanced PostgreSQL-specific functions -- PostgreSQL-specific functions
function Connection:vacuum(table_name, analyze) function Connection:vacuum(table_name, analyze)
local analyze_clause = analyze and " ANALYZE" or "" local analyze_clause = analyze and " ANALYZE" or ""
local table_clause = table_name and str.template(" ${table}", {table = table_name}) or "" local table_clause = table_name and str.template(" ${table}", {table = table_name}) or ""
@ -467,15 +461,7 @@ function Connection:reindex(name, type)
local valid_types = {"INDEX", "TABLE", "SCHEMA", "DATABASE", "SYSTEM"} local valid_types = {"INDEX", "TABLE", "SCHEMA", "DATABASE", "SYSTEM"}
local type_upper = str.upper(type) local type_upper = str.upper(type)
local valid = false if not tbl.contains(valid_types, type_upper) then
for _, valid_type in ipairs(valid_types) do
if type_upper == valid_type then
valid = true
break
end
end
if not valid then
error(str.template("Invalid REINDEX type: ${type}", {type = type})) error(str.template("Invalid REINDEX type: ${type}", {type = type}))
end end
@ -539,7 +525,7 @@ function Connection:describe_table(table_name, schema_name)
]], str.trim(schema_name), str.trim(table_name)) ]], str.trim(schema_name), str.trim(table_name))
end end
-- Enhanced JSON/JSONB helpers -- JSON/JSONB helpers
function Connection:json_extract(column, path) function Connection:json_extract(column, path)
if str.is_blank(column) or str.is_blank(path) then if str.is_blank(column) or str.is_blank(path) then
error("Column and path cannot be empty") error("Column and path cannot be empty")
@ -568,7 +554,7 @@ function Connection:jsonb_contained_by(column, value)
return str.template("${column} <@ '${value}'", {column = column, value = value}) return str.template("${column} <@ '${value}'", {column = column, value = value})
end end
-- Enhanced Array helpers -- Array helpers
function Connection:array_contains(column, value) function Connection:array_contains(column, value)
if str.is_blank(column) then if str.is_blank(column) then
error("Column cannot be empty") error("Column cannot be empty")
@ -583,15 +569,13 @@ function Connection:array_length(column)
return str.template("array_length(${column}, 1)", {column = column}) return str.template("array_length(${column}, 1)", {column = column})
end end
-- Enhanced connection management with DSN parsing -- Connection management
function postgres.parse_dsn(dsn) function postgres.parse_dsn(dsn)
if str.is_blank(dsn) then if str.is_blank(dsn) then
return nil, "DSN cannot be empty" return nil, "DSN cannot be empty"
end end
local parts = {} local parts = {}
-- Split by spaces and handle key=value pairs
for pair in str.trim(dsn):gmatch("[^%s]+") do for pair in str.trim(dsn):gmatch("[^%s]+") do
local key, value = pair:match("([^=]+)=(.+)") local key, value = pair:match("([^=]+)=(.+)")
if key and value then if key and value then
@ -618,7 +602,7 @@ end
postgres.open = postgres.connect postgres.open = postgres.connect
-- Enhanced quick execution functions -- Quick execution functions
function postgres.query(dsn, query_str, ...) function postgres.query(dsn, query_str, ...)
local conn = postgres.connect(dsn) local conn = postgres.connect(dsn)
if not conn then if not conn then
@ -659,7 +643,7 @@ function postgres.query_value(dsn, query_str, ...)
return nil return nil
end end
-- Enhanced migration helpers -- Migration helpers
function postgres.migrate(dsn, migrations, schema) function postgres.migrate(dsn, migrations, schema)
schema = schema or "public" schema = schema or "public"
local conn = postgres.connect(dsn) local conn = postgres.connect(dsn)
@ -667,7 +651,6 @@ function postgres.migrate(dsn, migrations, schema)
error("Failed to connect to PostgreSQL database for migration") error("Failed to connect to PostgreSQL database for migration")
end end
-- Create migrations table
conn:create_table("_migrations", conn:create_table("_migrations",
"id SERIAL PRIMARY KEY, name TEXT UNIQUE NOT NULL, applied_at TIMESTAMPTZ DEFAULT NOW()") "id SERIAL PRIMARY KEY, name TEXT UNIQUE NOT NULL, applied_at TIMESTAMPTZ DEFAULT NOW()")
@ -687,7 +670,6 @@ function postgres.migrate(dsn, migrations, schema)
break break
end end
-- Check if migration already applied
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = $1", local existing = conn:query_value("SELECT id FROM _migrations WHERE name = $1",
str.trim(migration.name)) str.trim(migration.name))
if not existing then if not existing then
@ -727,9 +709,9 @@ function postgres.migrate(dsn, migrations, schema)
return true return true
end end
-- Result processing utilities (same enhanced versions as SQLite) -- Simplified result processing utilities
function postgres.to_array(results, column_name) function postgres.to_array(results, column_name)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
return {} return {}
end end
@ -737,15 +719,11 @@ function postgres.to_array(results, column_name)
error("Column name cannot be empty") error("Column name cannot be empty")
end end
local array = {} return tbl.map(results, function(row) return row[column_name] end)
for i, row in ipairs(results) do
array[i] = row[column_name]
end
return array
end end
function postgres.to_map(results, key_column, value_column) function postgres.to_map(results, key_column, value_column)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
return {} return {}
end end
@ -762,7 +740,7 @@ function postgres.to_map(results, key_column, value_column)
end end
function postgres.group_by(results, column_name) function postgres.group_by(results, column_name)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
return {} return {}
end end
@ -770,32 +748,20 @@ function postgres.group_by(results, column_name)
error("Column name cannot be empty") error("Column name cannot be empty")
end end
local groups = {} return tbl.group_by(results, function(row) return row[column_name] end)
for _, row in ipairs(results) do
local key = row[column_name]
if not groups[key] then
groups[key] = {}
end
table.insert(groups[key], row)
end
return groups
end end
-- Enhanced debug helper (same as SQLite) -- Simplified debug helper
function postgres.print_results(results) function postgres.print_results(results)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
print("No results") print("No results")
return return
end end
-- Get column names from first row local columns = tbl.keys(results[1])
local columns = {} tbl.sort(columns)
for col, _ in pairs(results[1]) do
table.insert(columns, col)
end
table.sort(columns)
-- Calculate column widths for better formatting -- Calculate column widths
local widths = {} local widths = {}
for _, col in ipairs(columns) do for _, col in ipairs(columns) do
widths[col] = str.length(col) widths[col] = str.length(col)
@ -808,25 +774,20 @@ function postgres.print_results(results)
end end
end end
-- Print header with proper spacing -- Print header
local header_parts = {} local header_parts = tbl.map(columns, function(col) return str.pad_right(col, widths[col]) end)
local separator_parts = {} local separator_parts = tbl.map(columns, function(col) return str.repeat_("-", widths[col]) end)
for _, col in ipairs(columns) do
table.insert(header_parts, str.pad_right(col, widths[col]))
table.insert(separator_parts, str.repeat_("-", widths[col]))
end
print(str.join(header_parts, " | ")) print(tbl.concat(header_parts, " | "))
print(str.join(separator_parts, "-+-")) print(tbl.concat(separator_parts, "-+-"))
-- Print rows with proper spacing -- Print rows
for _, row in ipairs(results) do for _, row in ipairs(results) do
local value_parts = {} local value_parts = tbl.map(columns, function(col)
for _, col in ipairs(columns) do
local value = tostring(row[col] or "") local value = tostring(row[col] or "")
table.insert(value_parts, str.pad_right(value, widths[col])) return str.pad_right(value, widths[col])
end end)
print(str.join(value_parts, " | ")) print(tbl.concat(value_parts, " | "))
end end
end end

View File

@ -1,4 +1,5 @@
local str = require("string") local str = require("string")
local tbl = require("table")
local sqlite = {} local sqlite = {}
local Connection = {} local Connection = {}
@ -112,26 +113,20 @@ function Connection:begin()
return nil return nil
end end
-- Enhanced query builder helpers with string utilities -- Simplified query builders using table utilities
function Connection:insert(table_name, data) function Connection:insert(table_name, data)
if str.is_blank(table_name) then if str.is_blank(table_name) then
error("Table name cannot be empty") error("Table name cannot be empty")
end end
local keys = {} local keys = tbl.keys(data)
local values = {} local values = tbl.values(data)
local placeholders = {} local placeholders = tbl.map(keys, function() return "?" end)
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
end
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", { local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
table = table_name, table = table_name,
columns = str.join(keys, ", "), columns = tbl.concat(keys, ", "),
placeholders = str.join(placeholders, ", ") placeholders = tbl.concat(placeholders, ", ")
}) })
return self:exec(query, unpack(values)) return self:exec(query, unpack(values))
@ -142,33 +137,28 @@ function Connection:upsert(table_name, data, conflict_columns)
error("Table name cannot be empty") error("Table name cannot be empty")
end end
local keys = {} local keys = tbl.keys(data)
local values = {} local values = tbl.values(data)
local placeholders = {} local placeholders = tbl.map(keys, function() return "?" end)
local updates = {} local updates = tbl.map(keys, function(key)
return str.template("${key} = excluded.${key}", {key = key})
for key, value in pairs(data) do end)
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
table.insert(updates, str.template("${key} = excluded.${key}", {key = key}))
end
local conflict_clause = "" local conflict_clause = ""
if conflict_columns then if conflict_columns then
if type(conflict_columns) == "string" then if type(conflict_columns) == "string" then
conflict_clause = str.template("(${columns})", {columns = conflict_columns}) conflict_clause = str.template("(${columns})", {columns = conflict_columns})
else else
conflict_clause = str.template("(${columns})", {columns = str.join(conflict_columns, ", ")}) conflict_clause = str.template("(${columns})", {columns = tbl.concat(conflict_columns, ", ")})
end end
end end
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", { local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
table = table_name, table = table_name,
columns = str.join(keys, ", "), columns = tbl.concat(keys, ", "),
placeholders = str.join(placeholders, ", "), placeholders = tbl.concat(placeholders, ", "),
conflict = conflict_clause, conflict = conflict_clause,
updates = str.join(updates, ", ") updates = tbl.concat(updates, ", ")
}) })
return self:exec(query, unpack(values)) return self:exec(query, unpack(values))
@ -182,25 +172,21 @@ function Connection:update(table_name, data, where_clause, ...)
error("WHERE clause cannot be empty for UPDATE") error("WHERE clause cannot be empty for UPDATE")
end end
local sets = {} local keys = tbl.keys(data)
local values = {} local values = tbl.values(data)
local sets = tbl.map(keys, function(key)
for key, value in pairs(data) do return str.template("${key} = ?", {key = key})
table.insert(sets, str.template("${key} = ?", {key = key})) end)
table.insert(values, value)
end
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", { local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
table = table_name, table = table_name,
sets = str.join(sets, ", "), sets = tbl.concat(sets, ", "),
where = where_clause where = where_clause
}) })
-- Add WHERE clause parameters -- Add WHERE clause parameters
local where_args = {...} local where_args = {...}
for i = 1, #where_args do tbl.extend(values, where_args)
table.insert(values, where_args[i])
end
return self:exec(query, unpack(values)) return self:exec(query, unpack(values))
end end
@ -227,7 +213,7 @@ function Connection:select(table_name, columns, where_clause, ...)
columns = columns or "*" columns = columns or "*"
if type(columns) == "table" then if type(columns) == "table" then
columns = str.join(columns, ", ") columns = tbl.concat(columns, ", ")
end end
local query local query
@ -247,7 +233,7 @@ function Connection:select(table_name, columns, where_clause, ...)
end end
end end
-- Enhanced schema helpers with validation -- Schema helpers
function Connection:table_exists(table_name) function Connection:table_exists(table_name)
if str.is_blank(table_name) then if str.is_blank(table_name) then
return false return false
@ -267,11 +253,9 @@ function Connection:column_exists(table_name, column_name)
local result = self:query(str.template("PRAGMA table_info(${table})", {table = table_name})) local result = self:query(str.template("PRAGMA table_info(${table})", {table = table_name}))
if result then if result then
for _, row in ipairs(result) do return tbl.any(result, function(row)
if str.iequals(row.name, str.trim(column_name)) then return str.iequals(row.name, str.trim(column_name))
return true end)
end
end
end end
return false return false
end end
@ -315,7 +299,7 @@ function Connection:create_index(index_name, table_name, columns, unique)
end end
local unique_clause = unique and "UNIQUE " or "" local unique_clause = unique and "UNIQUE " or ""
local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns) local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table} (${columns})", { local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table} (${columns})", {
unique = unique_clause, unique = unique_clause,
@ -335,7 +319,7 @@ function Connection:drop_index(index_name)
return self:exec(query) return self:exec(query)
end end
-- Enhanced SQLite-specific functions -- SQLite-specific functions
function Connection:vacuum() function Connection:vacuum()
return self:exec("VACUUM") return self:exec("VACUUM")
end end
@ -355,35 +339,24 @@ end
function Connection:journal_mode(mode) function Connection:journal_mode(mode)
mode = mode or "WAL" mode = mode or "WAL"
if not str.contains(str.upper(mode), "DELETE") and local valid_modes = {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}
not str.contains(str.upper(mode), "TRUNCATE") and
not str.contains(str.upper(mode), "PERSIST") and if not tbl.contains(tbl.map(valid_modes, str.upper), str.upper(mode)) then
not str.contains(str.upper(mode), "MEMORY") and
not str.contains(str.upper(mode), "WAL") and
not str.contains(str.upper(mode), "OFF") then
error("Invalid journal mode: " .. mode) error("Invalid journal mode: " .. mode)
end end
return self:query(str.template("PRAGMA journal_mode = ${mode}", {mode = str.upper(mode)})) return self:query(str.template("PRAGMA journal_mode = ${mode}", {mode = str.upper(mode)}))
end end
function Connection:synchronous(level) function Connection:synchronous(level)
level = level or "NORMAL" level = level or "NORMAL"
local valid_levels = {"OFF", "NORMAL", "FULL", "EXTRA"} local valid_levels = {"OFF", "NORMAL", "FULL", "EXTRA"}
local level_upper = str.upper(level)
local valid = false if not tbl.contains(valid_levels, str.upper(level)) then
for _, valid_level in ipairs(valid_levels) do
if level_upper == valid_level then
valid = true
break
end
end
if not valid then
error("Invalid synchronous level: " .. level) error("Invalid synchronous level: " .. level)
end end
return self:exec(str.template("PRAGMA synchronous = ${level}", {level = level_upper})) return self:exec(str.template("PRAGMA synchronous = ${level}", {level = str.upper(level)}))
end end
function Connection:cache_size(size) function Connection:cache_size(size)
@ -397,28 +370,18 @@ end
function Connection:temp_store(mode) function Connection:temp_store(mode)
mode = mode or "MEMORY" mode = mode or "MEMORY"
local valid_modes = {"DEFAULT", "FILE", "MEMORY"} local valid_modes = {"DEFAULT", "FILE", "MEMORY"}
local mode_upper = str.upper(mode)
local valid = false if not tbl.contains(valid_modes, str.upper(mode)) then
for _, valid_mode in ipairs(valid_modes) do
if mode_upper == valid_mode then
valid = true
break
end
end
if not valid then
error("Invalid temp_store mode: " .. mode) error("Invalid temp_store mode: " .. mode)
end end
return self:exec(str.template("PRAGMA temp_store = ${mode}", {mode = mode_upper})) return self:exec(str.template("PRAGMA temp_store = ${mode}", {mode = str.upper(mode)}))
end end
-- Connection management with enhanced path handling -- Connection management
function sqlite.open(database_path) function sqlite.open(database_path)
database_path = database_path or ":memory:" database_path = database_path or ":memory:"
-- Clean up path
if database_path ~= ":memory:" then if database_path ~= ":memory:" then
database_path = str.trim(database_path) database_path = str.trim(database_path)
if str.is_blank(database_path) then if str.is_blank(database_path) then
@ -437,7 +400,7 @@ end
sqlite.connect = sqlite.open sqlite.connect = sqlite.open
-- Enhanced quick execution functions -- Quick execution functions
function sqlite.query(database_path, query_str, ...) function sqlite.query(database_path, query_str, ...)
local conn = sqlite.open(database_path) local conn = sqlite.open(database_path)
if not conn then if not conn then
@ -482,14 +445,13 @@ function sqlite.query_value(database_path, query_str, ...)
return nil return nil
end end
-- Enhanced migration helpers -- Migration helpers
function sqlite.migrate(database_path, migrations) function sqlite.migrate(database_path, migrations)
local conn = sqlite.open(database_path) local conn = sqlite.open(database_path)
if not conn then if not conn then
error("Failed to open SQLite database for migration") error("Failed to open SQLite database for migration")
end end
-- Create migrations table
conn:create_table("_migrations", conn:create_table("_migrations",
"id INTEGER PRIMARY KEY, name TEXT UNIQUE, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP") "id INTEGER PRIMARY KEY, name TEXT UNIQUE, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP")
@ -509,7 +471,6 @@ function sqlite.migrate(database_path, migrations)
break break
end end
-- Check if migration already applied
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?", local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?",
str.trim(migration.name)) str.trim(migration.name))
if not existing then if not existing then
@ -549,9 +510,9 @@ function sqlite.migrate(database_path, migrations)
return true return true
end end
-- Enhanced result processing utilities -- Simplified result processing using table utilities
function sqlite.to_array(results, column_name) function sqlite.to_array(results, column_name)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
return {} return {}
end end
@ -559,15 +520,11 @@ function sqlite.to_array(results, column_name)
error("Column name cannot be empty") error("Column name cannot be empty")
end end
local array = {} return tbl.map(results, function(row) return row[column_name] end)
for i, row in ipairs(results) do
array[i] = row[column_name]
end
return array
end end
function sqlite.to_map(results, key_column, value_column) function sqlite.to_map(results, key_column, value_column)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
return {} return {}
end end
@ -584,7 +541,7 @@ function sqlite.to_map(results, key_column, value_column)
end end
function sqlite.group_by(results, column_name) function sqlite.group_by(results, column_name)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
return {} return {}
end end
@ -592,36 +549,21 @@ function sqlite.group_by(results, column_name)
error("Column name cannot be empty") error("Column name cannot be empty")
end end
local groups = {} return tbl.group_by(results, function(row) return row[column_name] end)
for _, row in ipairs(results) do
local key = row[column_name]
if not groups[key] then
groups[key] = {}
end
table.insert(groups[key], row)
end
return groups
end end
-- Enhanced debug helper -- Simplified debug helper
function sqlite.print_results(results) function sqlite.print_results(results)
if not results or #results == 0 then if not results or tbl.is_empty(results) then
print("No results") print("No results")
return return
end end
-- Get column names from first row local columns = tbl.keys(results[1])
local columns = {} tbl.sort(columns)
for col, _ in pairs(results[1]) do
table.insert(columns, col)
end
table.sort(columns)
-- Calculate column widths for better formatting -- Calculate column widths
local widths = {} local widths = tbl.map_values(tbl.to_map(columns, function(col) return col end, function(col) return str.length(col) end), function(width) return width end)
for _, col in ipairs(columns) do
widths[col] = str.length(col)
end
for _, row in ipairs(results) do for _, row in ipairs(results) do
for _, col in ipairs(columns) do for _, col in ipairs(columns) do
@ -630,25 +572,20 @@ function sqlite.print_results(results)
end end
end end
-- Print header with proper spacing -- Print header
local header_parts = {} local header_parts = tbl.map(columns, function(col) return str.pad_right(col, widths[col]) end)
local separator_parts = {} local separator_parts = tbl.map(columns, function(col) return str.repeat_("-", widths[col]) end)
for _, col in ipairs(columns) do
table.insert(header_parts, str.pad_right(col, widths[col]))
table.insert(separator_parts, str.repeat_("-", widths[col]))
end
print(str.join(header_parts, " | ")) print(tbl.concat(header_parts, " | "))
print(str.join(separator_parts, "-+-")) print(tbl.concat(separator_parts, "-+-"))
-- Print rows with proper spacing -- Print rows
for _, row in ipairs(results) do for _, row in ipairs(results) do
local value_parts = {} local value_parts = tbl.map(columns, function(col)
for _, col in ipairs(columns) do
local value = tostring(row[col] or "") local value = tostring(row[col] or "")
table.insert(value_parts, str.pad_right(value, widths[col])) return str.pad_right(value, widths[col])
end end)
print(str.join(value_parts, " | ")) print(tbl.concat(value_parts, " | "))
end end
end end