enhance database modules with table utils

This commit is contained in:
Sky Johnson 2025-07-24 10:21:07 -05:00
parent 5551f16bc1
commit 71633b4b4b
3 changed files with 195 additions and 349 deletions

View File

@ -1,4 +1,5 @@
local str = require("string")
local tbl = require("table")
local mysql = {}
local Connection = {}
@ -132,26 +133,20 @@ function Connection:begin()
return nil
end
-- Enhanced MySQL-specific query builder helpers
-- Simplified MySQL-specific query builder helpers
function Connection:insert(table_name, data)
if str.is_blank(table_name) then
error("Table name cannot be empty")
end
local keys = {}
local values = {}
local placeholders = {}
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
end
local keys = tbl.keys(data)
local values = tbl.values(data)
local placeholders = tbl.map(keys, function() return "?" end)
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
table = table_name,
columns = str.join(keys, ", "),
placeholders = str.join(placeholders, ", ")
columns = tbl.concat(keys, ", "),
placeholders = tbl.concat(placeholders, ", ")
})
return self:exec(query, unpack(values))
@ -162,28 +157,21 @@ function Connection:upsert(table_name, data, update_data)
error("Table name cannot be empty")
end
local keys = {}
local values = {}
local placeholders = {}
local updates = {}
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
end
local keys = tbl.keys(data)
local values = tbl.values(data)
local placeholders = tbl.map(keys, function() return "?" end)
-- Use update_data if provided, otherwise update with same data
local update_source = update_data or data
for key, _ in pairs(update_source) do
table.insert(updates, str.template("${key} = VALUES(${key})", {key = key}))
end
local updates = tbl.map(tbl.keys(update_source), function(key)
return str.template("${key} = VALUES(${key})", {key = key})
end)
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON DUPLICATE KEY UPDATE ${updates}", {
table = table_name,
columns = str.join(keys, ", "),
placeholders = str.join(placeholders, ", "),
updates = str.join(updates, ", ")
columns = tbl.concat(keys, ", "),
placeholders = tbl.concat(placeholders, ", "),
updates = tbl.concat(updates, ", ")
})
return self:exec(query, unpack(values))
@ -194,20 +182,14 @@ function Connection:replace(table_name, data)
error("Table name cannot be empty")
end
local keys = {}
local values = {}
local placeholders = {}
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
end
local keys = tbl.keys(data)
local values = tbl.values(data)
local placeholders = tbl.map(keys, function() return "?" end)
local query = str.template("REPLACE INTO ${table} (${columns}) VALUES (${placeholders})", {
table = table_name,
columns = str.join(keys, ", "),
placeholders = str.join(placeholders, ", ")
columns = tbl.concat(keys, ", "),
placeholders = tbl.concat(placeholders, ", ")
})
return self:exec(query, unpack(values))
@ -221,25 +203,21 @@ function Connection:update(table_name, data, where_clause, ...)
error("WHERE clause cannot be empty for UPDATE")
end
local sets = {}
local values = {}
for key, value in pairs(data) do
table.insert(sets, str.template("${key} = ?", {key = key}))
table.insert(values, value)
end
local keys = tbl.keys(data)
local values = tbl.values(data)
local sets = tbl.map(keys, function(key)
return str.template("${key} = ?", {key = key})
end)
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
table = table_name,
sets = str.join(sets, ", "),
sets = tbl.concat(sets, ", "),
where = where_clause
})
-- Add WHERE clause parameters
local where_args = {...}
for i = 1, #where_args do
table.insert(values, where_args[i])
end
tbl.extend(values, where_args)
return self:exec(query, unpack(values))
end
@ -266,7 +244,7 @@ function Connection:select(table_name, columns, where_clause, ...)
columns = columns or "*"
if type(columns) == "table" then
columns = str.join(columns, ", ")
columns = tbl.concat(columns, ", ")
end
local query
@ -286,7 +264,7 @@ function Connection:select(table_name, columns, where_clause, ...)
end
end
-- Enhanced MySQL schema helpers
-- MySQL schema helpers
function Connection:database_exists(database_name)
if str.is_blank(database_name) then
return false
@ -441,7 +419,7 @@ function Connection:create_index(index_name, table_name, columns, unique, type)
local unique_clause = unique and "UNIQUE " or ""
local type_clause = type and str.template(" USING ${type}", {type = str.upper(type)}) or ""
local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns)
local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
local query = str.template("CREATE ${unique}INDEX ${index} ON ${table} (${columns})${type}", {
unique = unique_clause,
@ -465,7 +443,7 @@ function Connection:drop_index(index_name, table_name)
return self:exec(query)
end
-- Enhanced MySQL maintenance functions
-- MySQL maintenance functions
function Connection:optimize(table_name)
local table_clause = table_name and str.template(" ${table}", {table = table_name}) or ""
return self:query(str.template("OPTIMIZE TABLE${table}", {table = table_clause}))
@ -488,15 +466,7 @@ function Connection:check_table(table_name, options)
local valid_options = {"QUICK", "FAST", "MEDIUM", "EXTENDED", "CHANGED"}
local options_upper = str.upper(options)
local valid = false
for _, valid_option in ipairs(valid_options) do
if options_upper == valid_option then
valid = true
break
end
end
if valid then
if tbl.contains(valid_options, options_upper) then
options_clause = str.template(" ${options}", {options = options_upper})
end
end
@ -514,7 +484,7 @@ function Connection:analyze_table(table_name)
return self:query(str.template("ANALYZE TABLE ${table}", {table = table_name}))
end
-- Enhanced MySQL settings and introspection
-- MySQL settings and introspection
function Connection:show(what)
if str.is_blank(what) then
error("SHOW parameter cannot be empty")
@ -575,7 +545,7 @@ function Connection:show_table_status(table_name)
end
end
-- Enhanced MySQL user and privilege management
-- MySQL user and privilege management
function Connection:create_user(username, password, host)
if str.is_blank(username) or str.is_blank(password) then
error("Username and password cannot be empty")
@ -642,7 +612,7 @@ function Connection:flush_privileges()
return self:exec("FLUSH PRIVILEGES")
end
-- Enhanced MySQL variables and configuration
-- MySQL variables and configuration
function Connection:set_variable(name, value, global)
if str.is_blank(name) then
error("Variable name cannot be empty")
@ -683,7 +653,7 @@ function Connection:show_status(pattern)
end
end
-- Enhanced connection management
-- Connection management
function mysql.connect(dsn)
if str.is_blank(dsn) then
error("DSN cannot be empty")
@ -700,7 +670,7 @@ end
mysql.open = mysql.connect
-- Enhanced quick execution functions
-- Quick execution functions
function mysql.query(dsn, query_str, ...)
local conn = mysql.connect(dsn)
if not conn then
@ -741,7 +711,7 @@ function mysql.query_value(dsn, query_str, ...)
return nil
end
-- Enhanced migration helpers
-- Migration helpers
function mysql.migrate(dsn, migrations, database_name)
local conn = mysql.connect(dsn)
if not conn then
@ -813,9 +783,9 @@ function mysql.migrate(dsn, migrations, database_name)
return true
end
-- Result processing utilities (same enhanced versions)
-- Simplified result processing utilities
function mysql.to_array(results, column_name)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
return {}
end
@ -823,15 +793,11 @@ function mysql.to_array(results, column_name)
error("Column name cannot be empty")
end
local array = {}
for i, row in ipairs(results) do
array[i] = row[column_name]
end
return array
return tbl.map(results, function(row) return row[column_name] end)
end
function mysql.to_map(results, key_column, value_column)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
return {}
end
@ -848,7 +814,7 @@ function mysql.to_map(results, key_column, value_column)
end
function mysql.group_by(results, column_name)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
return {}
end
@ -856,32 +822,20 @@ function mysql.group_by(results, column_name)
error("Column name cannot be empty")
end
local groups = {}
for _, row in ipairs(results) do
local key = row[column_name]
if not groups[key] then
groups[key] = {}
end
table.insert(groups[key], row)
end
return groups
return tbl.group_by(results, function(row) return row[column_name] end)
end
-- Enhanced debug helper (same as others)
-- Simplified debug helper
function mysql.print_results(results)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
print("No results")
return
end
-- Get column names from first row
local columns = {}
for col, _ in pairs(results[1]) do
table.insert(columns, col)
end
table.sort(columns)
local columns = tbl.keys(results[1])
tbl.sort(columns)
-- Calculate column widths for better formatting
-- Calculate column widths
local widths = {}
for _, col in ipairs(columns) do
widths[col] = str.length(col)
@ -894,34 +848,28 @@ function mysql.print_results(results)
end
end
-- Print header with proper spacing
local header_parts = {}
local separator_parts = {}
for _, col in ipairs(columns) do
table.insert(header_parts, str.pad_right(col, widths[col]))
table.insert(separator_parts, str.repeat_("-", widths[col]))
end
-- Print header
local header_parts = tbl.map(columns, function(col) return str.pad_right(col, widths[col]) end)
local separator_parts = tbl.map(columns, function(col) return str.repeat_("-", widths[col]) end)
print(str.join(header_parts, " | "))
print(str.join(separator_parts, "-+-"))
print(tbl.concat(header_parts, " | "))
print(tbl.concat(separator_parts, "-+-"))
-- Print rows with proper spacing
-- Print rows
for _, row in ipairs(results) do
local value_parts = {}
for _, col in ipairs(columns) do
local value_parts = tbl.map(columns, function(col)
local value = tostring(row[col] or "")
table.insert(value_parts, str.pad_right(value, widths[col]))
end
print(str.join(value_parts, " | "))
return str.pad_right(value, widths[col])
end)
print(tbl.concat(value_parts, " | "))
end
end
-- Enhanced MySQL-specific utilities
-- MySQL-specific utilities
function mysql.escape_string(str_val)
if type(str_val) ~= "string" then
return tostring(str_val)
end
-- Basic escaping - in production, use proper escaping functions
return str.replace(str_val, "'", "\\'")
end
@ -932,7 +880,7 @@ function mysql.escape_identifier(name)
return str.template("`${name}`", {name = str.replace(name, "`", "``")})
end
-- Enhanced DSN builder helper
-- DSN builder helper
function mysql.build_dsn(options)
if type(options) ~= "table" then
error("Options must be a table")
@ -941,7 +889,7 @@ function mysql.build_dsn(options)
local parts = {}
if options.username and not str.is_blank(options.username) then
table.insert(parts, options.username)
tbl.insert(parts, options.username)
if options.password and not str.is_blank(options.password) then
parts[#parts] = str.template("${user}:${pass}", {
user = parts[#parts],
@ -952,9 +900,9 @@ function mysql.build_dsn(options)
end
if options.protocol and not str.is_blank(options.protocol) then
table.insert(parts, str.template("${protocol}(", {protocol = options.protocol}))
tbl.insert(parts, str.template("${protocol}(", {protocol = options.protocol}))
if options.host and not str.is_blank(options.host) then
table.insert(parts, options.host)
tbl.insert(parts, options.host)
if options.port then
parts[#parts] = str.template("${host}:${port}", {
host = parts[#parts],
@ -971,33 +919,33 @@ function mysql.build_dsn(options)
port = tostring(options.port)
})
end
table.insert(parts, host_part .. ")")
tbl.insert(parts, host_part .. ")")
end
if options.database and not str.is_blank(options.database) then
table.insert(parts, str.template("/${database}", {database = options.database}))
tbl.insert(parts, str.template("/${database}", {database = options.database}))
end
-- Add parameters
local params = {}
if options.charset and not str.is_blank(options.charset) then
table.insert(params, str.template("charset=${charset}", {charset = options.charset}))
tbl.insert(params, str.template("charset=${charset}", {charset = options.charset}))
end
if options.parseTime ~= nil then
table.insert(params, str.template("parseTime=${parse}", {parse = tostring(options.parseTime)}))
tbl.insert(params, str.template("parseTime=${parse}", {parse = tostring(options.parseTime)}))
end
if options.timeout and not str.is_blank(options.timeout) then
table.insert(params, str.template("timeout=${timeout}", {timeout = options.timeout}))
tbl.insert(params, str.template("timeout=${timeout}", {timeout = options.timeout}))
end
if options.tls and not str.is_blank(options.tls) then
table.insert(params, str.template("tls=${tls}", {tls = options.tls}))
tbl.insert(params, str.template("tls=${tls}", {tls = options.tls}))
end
if #params > 0 then
table.insert(parts, str.template("?${params}", {params = str.join(params, "&")}))
tbl.insert(parts, str.template("?${params}", {params = tbl.concat(params, "&")}))
end
return str.join(parts, "")
return tbl.concat(parts, "")
end
return mysql

View File

@ -1,4 +1,5 @@
local str = require("string")
local tbl = require("table")
local postgres = {}
local Connection = {}
@ -132,23 +133,20 @@ function Connection:begin()
return nil
end
-- Enhanced query builder helpers with PostgreSQL parameter numbering
-- Simplified PostgreSQL parameter builder
local function build_postgres_params(data)
local keys = {}
local values = {}
local keys = tbl.keys(data)
local values = tbl.values(data)
local placeholders = {}
local param_count = 0
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
param_count = param_count + 1
table.insert(placeholders, str.template("$${num}", {num = tostring(param_count)}))
for i = 1, #keys do
tbl.insert(placeholders, str.template("$${num}", {num = tostring(i)}))
end
return keys, values, placeholders, param_count
return keys, values, placeholders, #keys
end
-- Simplified query builders using table utilities
function Connection:insert(table_name, data, returning)
if str.is_blank(table_name) then
error("Table name cannot be empty")
@ -158,8 +156,8 @@ function Connection:insert(table_name, data, returning)
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
table = table_name,
columns = str.join(keys, ", "),
placeholders = str.join(placeholders, ", ")
columns = tbl.concat(keys, ", "),
placeholders = tbl.concat(placeholders, ", ")
})
if returning and not str.is_blank(returning) then
@ -179,27 +177,25 @@ function Connection:upsert(table_name, data, conflict_columns, returning)
end
local keys, values, placeholders = build_postgres_params(data)
local updates = {}
for _, key in ipairs(keys) do
table.insert(updates, str.template("${key} = EXCLUDED.${key}", {key = key}))
end
local updates = tbl.map(keys, function(key)
return str.template("${key} = EXCLUDED.${key}", {key = key})
end)
local conflict_clause = ""
if conflict_columns then
if type(conflict_columns) == "string" then
conflict_clause = str.template("(${columns})", {columns = conflict_columns})
else
conflict_clause = str.template("(${columns})", {columns = str.join(conflict_columns, ", ")})
conflict_clause = str.template("(${columns})", {columns = tbl.concat(conflict_columns, ", ")})
end
end
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
table = table_name,
columns = str.join(keys, ", "),
placeholders = str.join(placeholders, ", "),
columns = tbl.concat(keys, ", "),
placeholders = tbl.concat(placeholders, ", "),
conflict = conflict_clause,
updates = str.join(updates, ", ")
updates = tbl.concat(updates, ", ")
})
if returning and not str.is_blank(returning) then
@ -221,17 +217,16 @@ function Connection:update(table_name, data, where_clause, returning, ...)
error("WHERE clause cannot be empty for UPDATE")
end
local sets = {}
local values = {}
local param_count = 0
local keys = tbl.keys(data)
local values = tbl.values(data)
local param_count = #keys
for key, value in pairs(data) do
param_count = param_count + 1
table.insert(sets, str.template("${key} = $${num}", {
local sets = {}
for i, key in ipairs(keys) do
tbl.insert(sets, str.template("${key} = $${num}", {
key = key,
num = tostring(param_count)
num = tostring(i)
}))
table.insert(values, value)
end
-- Handle WHERE clause parameters
@ -239,15 +234,14 @@ function Connection:update(table_name, data, where_clause, returning, ...)
local where_clause_with_params = where_clause
for i = 1, #where_args do
param_count = param_count + 1
table.insert(values, where_args[i])
-- Replace ? with numbered parameter if needed
tbl.insert(values, where_args[i])
where_clause_with_params = str.replace(where_clause_with_params, "?",
str.template("$${num}", {num = tostring(param_count)}), 1)
end
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
table = table_name,
sets = str.join(sets, ", "),
sets = tbl.concat(sets, ", "),
where = where_clause_with_params
})
@ -275,7 +269,7 @@ function Connection:delete(table_name, where_clause, returning, ...)
local values = {}
local where_clause_with_params = where_clause
for i = 1, #where_args do
table.insert(values, where_args[i])
tbl.insert(values, where_args[i])
where_clause_with_params = str.replace(where_clause_with_params, "?",
str.template("$${num}", {num = tostring(i)}), 1)
end
@ -303,7 +297,7 @@ function Connection:select(table_name, columns, where_clause, ...)
columns = columns or "*"
if type(columns) == "table" then
columns = str.join(columns, ", ")
columns = tbl.concat(columns, ", ")
end
local query
@ -313,7 +307,7 @@ function Connection:select(table_name, columns, where_clause, ...)
local values = {}
local where_clause_with_params = where_clause
for i = 1, #where_args do
table.insert(values, where_args[i])
tbl.insert(values, where_args[i])
where_clause_with_params = str.replace(where_clause_with_params, "?",
str.template("$${num}", {num = tostring(i)}), 1)
end
@ -418,7 +412,7 @@ function Connection:create_index(index_name, table_name, columns, unique, method
local unique_clause = unique and "UNIQUE " or ""
local method_clause = method and str.template(" USING ${method}", {method = str.upper(method)}) or ""
local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns)
local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table}${method} (${columns})", {
unique = unique_clause,
@ -443,7 +437,7 @@ function Connection:drop_index(index_name, cascade)
return self:exec(query)
end
-- Enhanced PostgreSQL-specific functions
-- PostgreSQL-specific functions
function Connection:vacuum(table_name, analyze)
local analyze_clause = analyze and " ANALYZE" or ""
local table_clause = table_name and str.template(" ${table}", {table = table_name}) or ""
@ -467,15 +461,7 @@ function Connection:reindex(name, type)
local valid_types = {"INDEX", "TABLE", "SCHEMA", "DATABASE", "SYSTEM"}
local type_upper = str.upper(type)
local valid = false
for _, valid_type in ipairs(valid_types) do
if type_upper == valid_type then
valid = true
break
end
end
if not valid then
if not tbl.contains(valid_types, type_upper) then
error(str.template("Invalid REINDEX type: ${type}", {type = type}))
end
@ -539,7 +525,7 @@ function Connection:describe_table(table_name, schema_name)
]], str.trim(schema_name), str.trim(table_name))
end
-- Enhanced JSON/JSONB helpers
-- JSON/JSONB helpers
function Connection:json_extract(column, path)
if str.is_blank(column) or str.is_blank(path) then
error("Column and path cannot be empty")
@ -568,7 +554,7 @@ function Connection:jsonb_contained_by(column, value)
return str.template("${column} <@ '${value}'", {column = column, value = value})
end
-- Enhanced Array helpers
-- Array helpers
function Connection:array_contains(column, value)
if str.is_blank(column) then
error("Column cannot be empty")
@ -583,15 +569,13 @@ function Connection:array_length(column)
return str.template("array_length(${column}, 1)", {column = column})
end
-- Enhanced connection management with DSN parsing
-- Connection management
function postgres.parse_dsn(dsn)
if str.is_blank(dsn) then
return nil, "DSN cannot be empty"
end
local parts = {}
-- Split by spaces and handle key=value pairs
for pair in str.trim(dsn):gmatch("[^%s]+") do
local key, value = pair:match("([^=]+)=(.+)")
if key and value then
@ -618,7 +602,7 @@ end
postgres.open = postgres.connect
-- Enhanced quick execution functions
-- Quick execution functions
function postgres.query(dsn, query_str, ...)
local conn = postgres.connect(dsn)
if not conn then
@ -659,7 +643,7 @@ function postgres.query_value(dsn, query_str, ...)
return nil
end
-- Enhanced migration helpers
-- Migration helpers
function postgres.migrate(dsn, migrations, schema)
schema = schema or "public"
local conn = postgres.connect(dsn)
@ -667,7 +651,6 @@ function postgres.migrate(dsn, migrations, schema)
error("Failed to connect to PostgreSQL database for migration")
end
-- Create migrations table
conn:create_table("_migrations",
"id SERIAL PRIMARY KEY, name TEXT UNIQUE NOT NULL, applied_at TIMESTAMPTZ DEFAULT NOW()")
@ -687,7 +670,6 @@ function postgres.migrate(dsn, migrations, schema)
break
end
-- Check if migration already applied
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = $1",
str.trim(migration.name))
if not existing then
@ -727,9 +709,9 @@ function postgres.migrate(dsn, migrations, schema)
return true
end
-- Result processing utilities (same enhanced versions as SQLite)
-- Simplified result processing utilities
function postgres.to_array(results, column_name)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
return {}
end
@ -737,15 +719,11 @@ function postgres.to_array(results, column_name)
error("Column name cannot be empty")
end
local array = {}
for i, row in ipairs(results) do
array[i] = row[column_name]
end
return array
return tbl.map(results, function(row) return row[column_name] end)
end
function postgres.to_map(results, key_column, value_column)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
return {}
end
@ -762,7 +740,7 @@ function postgres.to_map(results, key_column, value_column)
end
function postgres.group_by(results, column_name)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
return {}
end
@ -770,32 +748,20 @@ function postgres.group_by(results, column_name)
error("Column name cannot be empty")
end
local groups = {}
for _, row in ipairs(results) do
local key = row[column_name]
if not groups[key] then
groups[key] = {}
end
table.insert(groups[key], row)
end
return groups
return tbl.group_by(results, function(row) return row[column_name] end)
end
-- Enhanced debug helper (same as SQLite)
-- Simplified debug helper
function postgres.print_results(results)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
print("No results")
return
end
-- Get column names from first row
local columns = {}
for col, _ in pairs(results[1]) do
table.insert(columns, col)
end
table.sort(columns)
local columns = tbl.keys(results[1])
tbl.sort(columns)
-- Calculate column widths for better formatting
-- Calculate column widths
local widths = {}
for _, col in ipairs(columns) do
widths[col] = str.length(col)
@ -808,25 +774,20 @@ function postgres.print_results(results)
end
end
-- Print header with proper spacing
local header_parts = {}
local separator_parts = {}
for _, col in ipairs(columns) do
table.insert(header_parts, str.pad_right(col, widths[col]))
table.insert(separator_parts, str.repeat_("-", widths[col]))
end
-- Print header
local header_parts = tbl.map(columns, function(col) return str.pad_right(col, widths[col]) end)
local separator_parts = tbl.map(columns, function(col) return str.repeat_("-", widths[col]) end)
print(str.join(header_parts, " | "))
print(str.join(separator_parts, "-+-"))
print(tbl.concat(header_parts, " | "))
print(tbl.concat(separator_parts, "-+-"))
-- Print rows with proper spacing
-- Print rows
for _, row in ipairs(results) do
local value_parts = {}
for _, col in ipairs(columns) do
local value_parts = tbl.map(columns, function(col)
local value = tostring(row[col] or "")
table.insert(value_parts, str.pad_right(value, widths[col]))
end
print(str.join(value_parts, " | "))
return str.pad_right(value, widths[col])
end)
print(tbl.concat(value_parts, " | "))
end
end

View File

@ -1,4 +1,5 @@
local str = require("string")
local tbl = require("table")
local sqlite = {}
local Connection = {}
@ -112,26 +113,20 @@ function Connection:begin()
return nil
end
-- Enhanced query builder helpers with string utilities
-- Simplified query builders using table utilities
function Connection:insert(table_name, data)
if str.is_blank(table_name) then
error("Table name cannot be empty")
end
local keys = {}
local values = {}
local placeholders = {}
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
end
local keys = tbl.keys(data)
local values = tbl.values(data)
local placeholders = tbl.map(keys, function() return "?" end)
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders})", {
table = table_name,
columns = str.join(keys, ", "),
placeholders = str.join(placeholders, ", ")
columns = tbl.concat(keys, ", "),
placeholders = tbl.concat(placeholders, ", ")
})
return self:exec(query, unpack(values))
@ -142,33 +137,28 @@ function Connection:upsert(table_name, data, conflict_columns)
error("Table name cannot be empty")
end
local keys = {}
local values = {}
local placeholders = {}
local updates = {}
for key, value in pairs(data) do
table.insert(keys, key)
table.insert(values, value)
table.insert(placeholders, "?")
table.insert(updates, str.template("${key} = excluded.${key}", {key = key}))
end
local keys = tbl.keys(data)
local values = tbl.values(data)
local placeholders = tbl.map(keys, function() return "?" end)
local updates = tbl.map(keys, function(key)
return str.template("${key} = excluded.${key}", {key = key})
end)
local conflict_clause = ""
if conflict_columns then
if type(conflict_columns) == "string" then
conflict_clause = str.template("(${columns})", {columns = conflict_columns})
else
conflict_clause = str.template("(${columns})", {columns = str.join(conflict_columns, ", ")})
conflict_clause = str.template("(${columns})", {columns = tbl.concat(conflict_columns, ", ")})
end
end
local query = str.template("INSERT INTO ${table} (${columns}) VALUES (${placeholders}) ON CONFLICT ${conflict} DO UPDATE SET ${updates}", {
table = table_name,
columns = str.join(keys, ", "),
placeholders = str.join(placeholders, ", "),
columns = tbl.concat(keys, ", "),
placeholders = tbl.concat(placeholders, ", "),
conflict = conflict_clause,
updates = str.join(updates, ", ")
updates = tbl.concat(updates, ", ")
})
return self:exec(query, unpack(values))
@ -182,25 +172,21 @@ function Connection:update(table_name, data, where_clause, ...)
error("WHERE clause cannot be empty for UPDATE")
end
local sets = {}
local values = {}
for key, value in pairs(data) do
table.insert(sets, str.template("${key} = ?", {key = key}))
table.insert(values, value)
end
local keys = tbl.keys(data)
local values = tbl.values(data)
local sets = tbl.map(keys, function(key)
return str.template("${key} = ?", {key = key})
end)
local query = str.template("UPDATE ${table} SET ${sets} WHERE ${where}", {
table = table_name,
sets = str.join(sets, ", "),
sets = tbl.concat(sets, ", "),
where = where_clause
})
-- Add WHERE clause parameters
local where_args = {...}
for i = 1, #where_args do
table.insert(values, where_args[i])
end
tbl.extend(values, where_args)
return self:exec(query, unpack(values))
end
@ -227,7 +213,7 @@ function Connection:select(table_name, columns, where_clause, ...)
columns = columns or "*"
if type(columns) == "table" then
columns = str.join(columns, ", ")
columns = tbl.concat(columns, ", ")
end
local query
@ -247,7 +233,7 @@ function Connection:select(table_name, columns, where_clause, ...)
end
end
-- Enhanced schema helpers with validation
-- Schema helpers
function Connection:table_exists(table_name)
if str.is_blank(table_name) then
return false
@ -267,11 +253,9 @@ function Connection:column_exists(table_name, column_name)
local result = self:query(str.template("PRAGMA table_info(${table})", {table = table_name}))
if result then
for _, row in ipairs(result) do
if str.iequals(row.name, str.trim(column_name)) then
return true
end
end
return tbl.any(result, function(row)
return str.iequals(row.name, str.trim(column_name))
end)
end
return false
end
@ -315,7 +299,7 @@ function Connection:create_index(index_name, table_name, columns, unique)
end
local unique_clause = unique and "UNIQUE " or ""
local columns_str = type(columns) == "table" and str.join(columns, ", ") or tostring(columns)
local columns_str = type(columns) == "table" and tbl.concat(columns, ", ") or tostring(columns)
local query = str.template("CREATE ${unique}INDEX IF NOT EXISTS ${index} ON ${table} (${columns})", {
unique = unique_clause,
@ -335,7 +319,7 @@ function Connection:drop_index(index_name)
return self:exec(query)
end
-- Enhanced SQLite-specific functions
-- SQLite-specific functions
function Connection:vacuum()
return self:exec("VACUUM")
end
@ -355,35 +339,24 @@ end
function Connection:journal_mode(mode)
mode = mode or "WAL"
if not str.contains(str.upper(mode), "DELETE") and
not str.contains(str.upper(mode), "TRUNCATE") and
not str.contains(str.upper(mode), "PERSIST") and
not str.contains(str.upper(mode), "MEMORY") and
not str.contains(str.upper(mode), "WAL") and
not str.contains(str.upper(mode), "OFF") then
local valid_modes = {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}
if not tbl.contains(tbl.map(valid_modes, str.upper), str.upper(mode)) then
error("Invalid journal mode: " .. mode)
end
return self:query(str.template("PRAGMA journal_mode = ${mode}", {mode = str.upper(mode)}))
end
function Connection:synchronous(level)
level = level or "NORMAL"
local valid_levels = {"OFF", "NORMAL", "FULL", "EXTRA"}
local level_upper = str.upper(level)
local valid = false
for _, valid_level in ipairs(valid_levels) do
if level_upper == valid_level then
valid = true
break
end
end
if not valid then
if not tbl.contains(valid_levels, str.upper(level)) then
error("Invalid synchronous level: " .. level)
end
return self:exec(str.template("PRAGMA synchronous = ${level}", {level = level_upper}))
return self:exec(str.template("PRAGMA synchronous = ${level}", {level = str.upper(level)}))
end
function Connection:cache_size(size)
@ -397,28 +370,18 @@ end
function Connection:temp_store(mode)
mode = mode or "MEMORY"
local valid_modes = {"DEFAULT", "FILE", "MEMORY"}
local mode_upper = str.upper(mode)
local valid = false
for _, valid_mode in ipairs(valid_modes) do
if mode_upper == valid_mode then
valid = true
break
end
end
if not valid then
if not tbl.contains(valid_modes, str.upper(mode)) then
error("Invalid temp_store mode: " .. mode)
end
return self:exec(str.template("PRAGMA temp_store = ${mode}", {mode = mode_upper}))
return self:exec(str.template("PRAGMA temp_store = ${mode}", {mode = str.upper(mode)}))
end
-- Connection management with enhanced path handling
-- Connection management
function sqlite.open(database_path)
database_path = database_path or ":memory:"
-- Clean up path
if database_path ~= ":memory:" then
database_path = str.trim(database_path)
if str.is_blank(database_path) then
@ -437,7 +400,7 @@ end
sqlite.connect = sqlite.open
-- Enhanced quick execution functions
-- Quick execution functions
function sqlite.query(database_path, query_str, ...)
local conn = sqlite.open(database_path)
if not conn then
@ -482,14 +445,13 @@ function sqlite.query_value(database_path, query_str, ...)
return nil
end
-- Enhanced migration helpers
-- Migration helpers
function sqlite.migrate(database_path, migrations)
local conn = sqlite.open(database_path)
if not conn then
error("Failed to open SQLite database for migration")
end
-- Create migrations table
conn:create_table("_migrations",
"id INTEGER PRIMARY KEY, name TEXT UNIQUE, applied_at DATETIME DEFAULT CURRENT_TIMESTAMP")
@ -509,7 +471,6 @@ function sqlite.migrate(database_path, migrations)
break
end
-- Check if migration already applied
local existing = conn:query_value("SELECT id FROM _migrations WHERE name = ?",
str.trim(migration.name))
if not existing then
@ -549,9 +510,9 @@ function sqlite.migrate(database_path, migrations)
return true
end
-- Enhanced result processing utilities
-- Simplified result processing using table utilities
function sqlite.to_array(results, column_name)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
return {}
end
@ -559,15 +520,11 @@ function sqlite.to_array(results, column_name)
error("Column name cannot be empty")
end
local array = {}
for i, row in ipairs(results) do
array[i] = row[column_name]
end
return array
return tbl.map(results, function(row) return row[column_name] end)
end
function sqlite.to_map(results, key_column, value_column)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
return {}
end
@ -584,7 +541,7 @@ function sqlite.to_map(results, key_column, value_column)
end
function sqlite.group_by(results, column_name)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
return {}
end
@ -592,36 +549,21 @@ function sqlite.group_by(results, column_name)
error("Column name cannot be empty")
end
local groups = {}
for _, row in ipairs(results) do
local key = row[column_name]
if not groups[key] then
groups[key] = {}
end
table.insert(groups[key], row)
end
return groups
return tbl.group_by(results, function(row) return row[column_name] end)
end
-- Enhanced debug helper
-- Simplified debug helper
function sqlite.print_results(results)
if not results or #results == 0 then
if not results or tbl.is_empty(results) then
print("No results")
return
end
-- Get column names from first row
local columns = {}
for col, _ in pairs(results[1]) do
table.insert(columns, col)
end
table.sort(columns)
local columns = tbl.keys(results[1])
tbl.sort(columns)
-- Calculate column widths for better formatting
local widths = {}
for _, col in ipairs(columns) do
widths[col] = str.length(col)
end
-- Calculate column widths
local widths = tbl.map_values(tbl.to_map(columns, function(col) return col end, function(col) return str.length(col) end), function(width) return width end)
for _, row in ipairs(results) do
for _, col in ipairs(columns) do
@ -630,25 +572,20 @@ function sqlite.print_results(results)
end
end
-- Print header with proper spacing
local header_parts = {}
local separator_parts = {}
for _, col in ipairs(columns) do
table.insert(header_parts, str.pad_right(col, widths[col]))
table.insert(separator_parts, str.repeat_("-", widths[col]))
end
-- Print header
local header_parts = tbl.map(columns, function(col) return str.pad_right(col, widths[col]) end)
local separator_parts = tbl.map(columns, function(col) return str.repeat_("-", widths[col]) end)
print(str.join(header_parts, " | "))
print(str.join(separator_parts, "-+-"))
print(tbl.concat(header_parts, " | "))
print(tbl.concat(separator_parts, "-+-"))
-- Print rows with proper spacing
-- Print rows
for _, row in ipairs(results) do
local value_parts = {}
for _, col in ipairs(columns) do
local value_parts = tbl.map(columns, function(col)
local value = tostring(row[col] or "")
table.insert(value_parts, str.pad_right(value, widths[col]))
end
print(str.join(value_parts, " | "))
return str.pad_right(value, widths[col])
end)
print(tbl.concat(value_parts, " | "))
end
end